feat: add KServe gRPC v2 support #2176

Merged · 16 commits · Aug 24, 2023
15 changes: 12 additions & 3 deletions kubernetes/kserve/Dockerfile
@@ -1,13 +1,13 @@
# syntax = docker/dockerfile:experimental
#
# Following comments have been shamelessly copied from https://github.com/pytorch/pytorch/blob/master/Dockerfile
#
#
# NOTE: To build this you will need a docker version > 18.06 with
# experimental enabled and DOCKER_BUILDKIT=1
#
# If you do not use buildkit you are not going to have a good time
#
# For reference:
# For reference:
# https://docs.docker.com/develop/develop-images/build_enhancements

ARG BASE_IMAGE=pytorch/torchserve:latest
@@ -24,9 +24,18 @@ RUN pip install -r requirements.txt
COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh
COPY kserve_wrapper kserve_wrapper

COPY ./*.proto ./kserve_wrapper/

RUN python -m grpc_tools.protoc \
--proto_path=./kserve_wrapper \
--python_out=./kserve_wrapper \
--grpc_python_out=./kserve_wrapper \
./kserve_wrapper/inference.proto \
./kserve_wrapper/management.proto
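# The protoc invocation above generates inference_pb2.py / inference_pb2_grpc.py and
# management_pb2.py / management_pb2_grpc.py inside kserve_wrapper, which the wrapper
# imports at runtime (see kserve_wrapper/TorchserveModel.py) to call TorchServe over gRPC.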

COPY config.properties config.properties

USER model-server

ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"]

2 changes: 2 additions & 0 deletions kubernetes/kserve/Dockerfile.dev
@@ -69,6 +69,8 @@ RUN if [ "$MACHINE_TYPE" = "gpu" ]; then export USE_CUDA=1; fi \
&& chmod +x /usr/local/bin/dockerd-entrypoint.sh \
&& chown -R model-server /home/model-server \
&& cp -R kubernetes/kserve/kserve_wrapper /home/model-server/kserve_wrapper \
&& cp frontend/server/src/main/resources/proto/*.proto /home/model-server/kserve_wrapper \
&& python -m grpc_tools.protoc --proto_path=/home/model-server/kserve_wrapper --python_out=/home/model-server/kserve_wrapper --grpc_python_out=/home/model-server/kserve_wrapper /home/model-server/kserve_wrapper/inference.proto /home/model-server/kserve_wrapper/management.proto \
&& cp kubernetes/kserve/config.properties /home/model-server/config.properties \
&& mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store

4 changes: 2 additions & 2 deletions kubernetes/kserve/README.md
@@ -30,10 +30,10 @@ Currently, KServe supports the Inference API for all the existing models but tex
./build_image.sh -g -t <repository>/<image>:<tag>
```

### Docker Image Dev Build
- To create a dev image

```bash
DOCKER_BUILDKIT=1 docker build -f Dockerfile.dev -t pytorch/torchserve-kfs:latest-dev .
./build_image.sh -g -d -t <repository>/<image>:<tag>
```

## Running Torchserve inference service in KServe cluster
9 changes: 8 additions & 1 deletion kubernetes/kserve/build_image.sh
@@ -2,6 +2,7 @@

DOCKER_TAG="pytorch/torchserve-kfs:latest"
BASE_IMAGE="pytorch/torchserve:latest"
DOCKER_FILE="Dockerfile"

for arg in "$@"
do
@@ -18,6 +19,10 @@ do
BASE_IMAGE="pytorch/torchserve:latest-gpu"
shift
;;
-d|--dev)
DOCKER_FILE="Dockerfile.dev"
shift
;;
-t|--tag)
DOCKER_TAG="$2"
shift
@@ -26,4 +31,6 @@
esac
done

DOCKER_BUILDKIT=1 docker build --file Dockerfile --build-arg BASE_IMAGE=$BASE_IMAGE -t "$DOCKER_TAG" .
cp ../../frontend/server/src/main/resources/proto/*.proto .

DOCKER_BUILDKIT=1 docker build --file "$DOCKER_FILE" --build-arg BASE_IMAGE=$BASE_IMAGE -t "$DOCKER_TAG" .
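
Note: the script now copies the frontend's `.proto` definitions into the build context before running `docker build`, because the Dockerfile's `COPY ./*.proto ./kserve_wrapper/` step expects them alongside the Dockerfile.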
@@ -1,9 +1,10 @@
{
"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298",
"inputs": [
{
"data": ["iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII="],
"datatype": "BYTES",
"name": "e8d5afed-0a56-4deb-ac9c-352663f51b93",
"name": "312a4eb0-0ca7-4803-a101-a6d2c18486fe",
"shape": [-1]
}
]
@@ -0,0 +1,11 @@
{
"model_name": "mnist",
"inputs": [{
"name": "312a4eb0-0ca7-4803-a101-a6d2c18486fe",
"shape": [-1],
"datatype": "BYTES",
"contents": {
"bytes_contents": ["iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII="]
}
}]
}
@@ -0,0 +1,12 @@
{
"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298",
"model_name": "mnist",
"inputs": [{
"name": "input-0",
"shape": [1, 28, 28],
"datatype": "FP32",
"contents": {
"fp32_contents": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23919999599456787, 0.011800000444054604, 0.1647000014781952, 0.4627000093460083, 0.7569000124931335, 0.4627000093460083, 0.4627000093460083, 0.23919999599456787, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05490000173449516, 0.7020000219345093, 0.9607999920845032, 0.9254999756813049, 0.9490000009536743, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9607999920845032, 0.9215999841690063, 0.3294000029563904, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.592199981212616, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.8353000283241272, 0.7529000043869019, 0.6980000138282776, 0.6980000138282776, 0.7059000134468079, 0.9961000084877014, 0.9961000084877014, 0.9451000094413757, 0.18039999902248383, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16859999299049377, 0.9215999841690063, 0.9961000084877014, 0.8863000273704529, 0.25099998712539673, 0.10980000346899033, 0.0471000000834465, 0.0, 0.0, 0.007799999788403511, 0.5019999742507935, 0.9882000088691711, 1.0, 0.6783999800682068, 0.06669999659061432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.21960000693798065, 0.9961000084877014, 0.9922000169754028, 0.4196000099182129, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5254999995231628, 0.980400025844574, 0.9961000084877014, 0.29409998655319214, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.24709999561309814, 0.9961000084877014, 0.6195999979972839, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8666999936103821, 0.9961000084877014, 0.6157000064849854, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7608000040054321, 0.9961000084877014, 0.40389999747276306, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5881999731063843, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13330000638961792, 0.8626999855041504, 0.9373000264167786, 0.22750000655651093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49410000443458557, 0.9961000084877014, 0.6705999970436096, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.9373000264167786, 0.2353000044822693, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04309999942779541, 0.8587999939918518, 
0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38429999351501465, 0.9961000084877014, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6352999806404114, 0.9961000084877014, 0.819599986076355, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38429999351501465, 0.9961000084877014, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20000000298023224, 0.9333000183105469, 0.9961000084877014, 0.29409998655319214, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38429999351501465, 0.9961000084877014, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20000000298023224, 0.6470999717712402, 0.9961000084877014, 0.7646999955177307, 0.015699999406933784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2587999999523163, 0.9451000094413757, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011800000444054604, 0.6549000144004822, 0.9961000084877014, 0.8902000188827515, 0.21570000052452087, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.8353000283241272, 0.07840000092983246, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18039999902248383, 0.5960999727249146, 0.7922000288963318, 0.9961000084877014, 0.9961000084877014, 0.24709999561309814, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.9961000084877014, 0.800000011920929, 0.7059000134468079, 0.7059000134468079, 0.7059000134468079, 0.7059000134468079, 0.7059000134468079, 0.9215999841690063, 0.9961000084877014, 0.9961000084877014, 0.9175999760627747, 0.6118000149726868, 0.03920000046491623, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3176000118255615, 0.8039000034332275, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9882000088691711, 0.9175999760627747, 0.4706000089645386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10199999809265137, 0.8234999775886536, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.6000000238418579, 0.40779998898506165, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
}
}]
}
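
Together, the two new request files show the two KServe v2 encodings exercised by this PR: the bytes variant sends a base64-encoded PNG via `bytes_contents` with shape `[-1]`, while the tensor variant sends the flattened 28×28 image as `fp32_contents` with shape `[1, 28, 28]`.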
59 changes: 35 additions & 24 deletions kubernetes/kserve/kserve_wrapper/README.md
@@ -26,7 +26,7 @@ Follow the below steps to serve the MNIST Model :
- Step 2 : Install KServe as below:

```bash
pip install kserve>=0.9.0
pip install kserve>=0.9.0 grpcio protobuf grpcio-tools
```

- Step 4 : Run the Install Dependencies script
@@ -59,11 +59,11 @@ sudo mkdir -p /mnt/models/model-store

For v1 protocol

``export TS_SERVICE_ENVELOPE=kserve`
`export TS_SERVICE_ENVELOPE=kserve`

For v2 protocol

``export TS_SERVICE_ENVELOPE=kservev2`
`export TS_SERVICE_ENVELOPE=kservev2`

- Step 10: Move the config.properties to /mnt/models/config/.
The config.properties file is as below :
@@ -93,6 +93,26 @@ torchserve --start --ts-config /mnt/models/config/config.properties

- Step 12: Run the below command to start the KFServer

- Step 13: Set protocol version

For v1 protocol

`export PROTOCOL_VERSION=v1`

For v2 protocol

`export PROTOCOL_VERSION=v2`

For the gRPC v2 protocol

`export PROTOCOL_VERSION=grpc-v2`
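
The kserve_wrapper reads `PROTOCOL_VERSION` at startup to pick the predictor protocol; with `grpc-v2`, inference traffic is routed to TorchServe's gRPC endpoint rather than its REST API (see the `TorchserveModel` changes below).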

- Generate the Python gRPC client stubs using the proto files (a quick smoke test follows below):

```bash
python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```
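
To sanity-check the generated stubs, the sketch below makes one direct call to TorchServe's gRPC inference API. It is an illustration, not part of this PR: it assumes the stubs generated into `ts_scripts/` above, TorchServe's default gRPC inference port `7070`, a registered `mnist` model, and a placeholder image path `0.png`.

```python
import sys

import grpc

# Make the stubs generated into ts_scripts/ importable.
sys.path.append("ts_scripts")
import inference_pb2  # generated from inference.proto
import inference_pb2_grpc  # generated gRPC client stubs

# TorchServe's gRPC inference API listens on port 7070 by default.
channel = grpc.insecure_channel("localhost:7070")
stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)

# Send a raw image as the "data" input of a PredictionsRequest.
with open("0.png", "rb") as f:
    request = inference_pb2.PredictionsRequest(
        model_name="mnist", input={"data": f.read()}
    )

response = stub.Predictions(request)
print(response.prediction.decode("utf-8"))
```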

```bash
python3 serve/kubernetes/kserve/kserve_wrapper/__main__.py
```
@@ -127,7 +147,7 @@ Output:

The curl request for explain is as below:

```
```bash
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/v1/mnist.json http://0.0.0.0:8080/v1/models/mnist:explain
```

@@ -146,7 +166,7 @@ For v2 protocol
The curl request for inference is as below:

```bash
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/mnist_v2.json http://0.0.0.0:8080/v2/models/mnist/infer
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor.json http://0.0.0.0:8080/v2/models/mnist/infer
```

Response:
@@ -167,29 +187,20 @@ Response:
}
```

The curl request for explain is as below:
For grpc-v2 protocol

```
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/v1/mnist.json http://0.0.0.0:8080/v2/models/mnist/explain
```

- Download the proto file

```bash
curl -O https://raw.githubusercontent.com/kserve/kserve/master/docs/predict-api/v2/grpc_predict_v2.proto
```

Response:
- Download [grpcurl](https://github.com/fullstorydev/grpcurl)

```json
{
"id": "3482b766-0483-40e9-84b0-8ce8d4d1576e",
"model_name": "mnist",
"model_version": "1.0",
"outputs": [{
"name": "explain",
"shape": [1, 28, 28],
"datatype": "FP64",
"data": [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0, 0.0
...
...
]
}]
}
```

Make the gRPC request:

```bash
grpcurl -vv -plaintext -proto grpc_predict_v2.proto -d @ localhost:8081 inference.GRPCInferenceService.ModelInfer <<< $(cat "serve/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor_grpc.json")
```
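
Here `-d @` tells grpcurl to read the request body from stdin, which the `<<<` here-string feeds with the v2 tensor request JSON.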

## KServe Wrapper Testing in Local for BERT
7 changes: 6 additions & 1 deletion kubernetes/kserve/kserve_wrapper/TSModelRepository.py
@@ -14,7 +14,12 @@ class TSModelRepository(ModelRepository):
as inputs to the TSModel Repository.
"""

def __init__(self, inference_address: str, management_address: str, model_dir: str):
def __init__(
self,
inference_address: str,
management_address: str,
model_dir: str,
):
"""The Inference Address, Management Address and the Model Directory from the kserve
side is initialized here.

95 changes: 92 additions & 3 deletions kubernetes/kserve/kserve_wrapper/TorchserveModel.py
@@ -2,20 +2,37 @@
return a KServe side response """
import logging
import pathlib
from enum import Enum
from typing import Dict, Union

import grpc
import inference_pb2_grpc
import kserve
from gprc_utils import from_ts_grpc, to_ts_grpc
from inference_pb2 import PredictionResponse
from kserve.errors import ModelMissingError
from kserve.model import Model as Model
from kserve.protocol.grpc.grpc_predict_v2_pb2 import (
ModelInferRequest,
ModelInferResponse,
)
from kserve.protocol.infer_type import InferRequest, InferResponse
from kserve.storage import Storage

logging.basicConfig(level=kserve.constants.KSERVE_LOGLEVEL)

PREDICTOR_URL_FORMAT = PREDICTOR_V2_URL_FORMAT = "http://{0}/predictions/{1}"
EXPLAINER_URL_FORMAT = EXPLAINER_V2_URL_FORMAT = "http://{0}/explanations/{1}"
EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}"
REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}"
UNREGISTER_URL_FORMAT = "{0}/models/{1}"


class PredictorProtocol(Enum):
REST_V1 = "v1"
REST_V2 = "v2"
GRPC_V2 = "grpc-v2"


class TorchserveModel(Model):
"""The torchserve side inference and explain end-points requests are handled to
return a KServe side response
@@ -25,7 +42,15 @@ class TorchserveModel(Model):
side predict and explain http requests.
"""

def __init__(self, name, inference_address, management_address, model_dir):
def __init__(
self,
name,
inference_address,
management_address,
grpc_inference_address,
protocol,
model_dir,
):
"""The Model Name, Inference Address, Management Address and the model directory
are specified.

@@ -45,10 +70,74 @@ def __init__(self, name, inference_address, management_address, model_dir):
self.inference_address = inference_address
self.management_address = management_address
self.model_dir = model_dir
self.protocol = protocol

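# For the grpc-v2 protocol, route predictor calls to TorchServe's gRPC
# inference address instead of the REST inference address.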
if self.protocol == PredictorProtocol.GRPC_V2.value:
self.predictor_host = grpc_inference_address

logging.info("Predict URL set to %s", self.predictor_host)
self.explainer_host = self.predictor_host
logging.info("Explain URL set to %s", self.explainer_host)
logging.info("Protocol version is %s", self.protocol)

def grpc_client(self):
# Lazily create one shared async stub for TorchServe's gRPC inference API.
if self._grpc_client_stub is None:
self.channel = grpc.aio.insecure_channel(self.predictor_host)
self._grpc_client_stub = inference_pb2_grpc.InferenceAPIsServiceStub(
self.channel
)
return self._grpc_client_stub

Review comment (Collaborator): I don't see where this function is called. Could you add some comments (e.g., which function calls this one)?

async def _grpc_predict(
self,
payload: Union[ModelInferRequest, InferRequest],
headers: Dict[str, str] = None,
) -> ModelInferResponse:
"""Overrides the `_grpc_predict` method in Model class. The predict method calls
the `_grpc_predict` method if the self.protocol is "grpc_v2"

Args:
payload (ModelInferRequest|InferRequest): The request payload passed from the ``predict`` handler.

Returns:
Dict: Torchserve grpc response.
"""
payload = to_ts_grpc(payload)
grpc_stub = self.grpc_client()
async_result = await grpc_stub.Predictions(payload)
return async_result

Review comment (Collaborator): I don't see where this function is called. Could you add some comments (e.g., which function calls this one)?

def postprocess(
self,
response: Union[Dict, InferResponse, ModelInferResponse, PredictionResponse],
headers: Dict[str, str] = None,
) -> Union[Dict, ModelInferResponse]:
"""This method converts the v2 infer response types to gRPC or REST.
For gRPC request it converts InferResponse to gRPC message or directly returns ModelInferResponse from
predictor call or converts TS PredictionResponse to ModelInferResponse.
For REST request it converts ModelInferResponse to Dict or directly returns from predictor call.

Args:
response (Dict|InferResponse|ModelInferResponse|PredictionResponse): The response passed from ``predict`` handler.
headers (Dict): Request headers.

Returns:
Dict: post-processed response.
"""
if headers:
if "grpc" in headers.get("user-agent", ""):
if isinstance(response, ModelInferResponse):
return response
elif isinstance(response, InferResponse):
return response.to_grpc()
elif isinstance(response, PredictionResponse):
return from_ts_grpc(response)
if "application/json" in headers.get("content-type", ""):
# If the original request is REST, convert the gRPC predict response to dict
if isinstance(response, ModelInferResponse):
return InferResponse.from_grpc(response).to_rest()
elif isinstance(response, InferResponse):
return response.to_rest()
return response

def load(self) -> bool:
"""This method validates model availabilty in the model directory
Expand Down