[Tool] Add a tool to test TorchServe. (#468)
* Add `title` option in `show_result_pyplot`.

* Add test_torchserver.py

* Add docs about test torchserve

* Update docs and result output.

* Update chinese docs.
mzr1996 authored Oct 14, 2021
1 parent fd0f5cc commit 10e8495
Showing 6 changed files with 138 additions and 17 deletions.
46 changes: 39 additions & 7 deletions docs/tools/model_serving.md
@@ -10,7 +10,19 @@ python tools/deployment/mmcls2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
--model-name ${MODEL_NAME}
```

-**Note**: ${MODEL_STORE} needs to be an absolute path to a folder.
+```{note}
+${MODEL_STORE} needs to be an absolute path to a folder.
+```

Example:

```shell
python tools/deployment/mmcls2torchserve.py \
configs/resnet/resnet18_b32x8_imagenet.py \
checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth \
--output-folder ./checkpoints \
--model-name resnet18_in1k
```
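As the note above says, `${MODEL_STORE}` must be an absolute path to a folder. If you assemble paths in a script rather than in the shell, a relative path such as `./checkpoints` can be expanded first — a small sketch, equivalent to the shell's `realpath`:

```python
import os

# Expand the relative model-store path to the absolute form that
# the converted model's store directory must use.
model_store = os.path.abspath('./checkpoints')
assert os.path.isabs(model_store)
print(model_store)  # e.g. /home/user/mmclassification/checkpoints
```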

## 2. Build `mmcls-serve` docker image

@@ -31,25 +43,45 @@ docker run --rm \
--cpus 8 \
--gpus device=0 \
-p8080:8080 -p8081:8081 -p8082:8082 \
---mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
+--mount type=bind,source=`realpath ./checkpoints`,target=/home/model-server/model-store \
mmcls-serve:latest
```

```{note}
`realpath ./checkpoints` points to the absolute path of "./checkpoints", and you can replace it with the absolute path where you store torchserve models.
```

[Read the docs](https://github.com/pytorch/serve/blob/master/docs/rest_api.md) about the Inference (8080), Management (8081) and Metrics (8082) APIs.

## 4. Test deployment

```shell
-curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/3dogs.jpg
-curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T 3dogs.jpg
+curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T demo/demo.JPEG
```

You should obtain a response similar to:

```json
{
-"pred_label": 245,
-"pred_score": 0.5536593794822693,
-"pred_class": "French bulldog"
+"pred_label": 58,
+"pred_score": 0.38102269172668457,
+"pred_class": "water snake"
}
```
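The same request can be made from Python, mirroring what `test_torchserver.py` (added in this commit) does with `requests`. This is a sketch; the model name and address are the values from the examples above, and a server from step 3 must be running for `predict` to succeed:

```python
def build_url(inference_addr, model_name):
    """Build the TorchServe prediction endpoint URL, as test_torchserver.py does."""
    return 'http://' + inference_addr + '/predictions/' + model_name

def predict(img_path, model_name, inference_addr='127.0.0.1:8080'):
    """POST an image file to a running TorchServe instance, return the parsed JSON."""
    import requests  # third-party; this commit adds it to requirements/optional.txt
    url = build_url(inference_addr, model_name)
    with open(img_path, 'rb') as image:
        response = requests.post(url, image)
    return response.json()

print(build_url('127.0.0.1:8080', 'resnet18_in1k'))
# http://127.0.0.1:8080/predictions/resnet18_in1k
```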

You can also use `test_torchserver.py` to compare the results of TorchServe and PyTorch, and visualize them.

```shell
python tools/deployment/test_torchserver.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${MODEL_NAME}
[--inference-addr ${INFERENCE_ADDR}] [--device ${DEVICE}]
```

Example:

```shell
python tools/deployment/test_torchserver.py \
demo/demo.JPEG \
configs/resnet/resnet18_b32x8_imagenet.py \
checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth \
resnet18_in1k
```
46 changes: 39 additions & 7 deletions docs_zh-CN/tools/model_serving.md
@@ -10,7 +10,19 @@ python tools/deployment/mmcls2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
--model-name ${MODEL_NAME}
```

-**Note**: ${MODEL_STORE} needs to be an absolute path to a folder.
+```{note}
+${MODEL_STORE} needs to be an absolute path to a folder.
+```

Example:

```shell
python tools/deployment/mmcls2torchserve.py \
configs/resnet/resnet18_b32x8_imagenet.py \
checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth \
--output-folder ./checkpoints \
--model-name resnet18_in1k
```

## 2. Build the `mmcls-serve` docker image

@@ -31,25 +43,45 @@ docker run --rm \
--cpus 8 \
--gpus device=0 \
-p8080:8080 -p8081:8081 -p8082:8082 \
---mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
+--mount type=bind,source=`realpath ./checkpoints`,target=/home/model-server/model-store \
mmcls-serve:latest
```

```{note}
`realpath ./checkpoints` is the absolute path of "./checkpoints", and you can replace it with the absolute path of the directory where you store TorchServe models.
```

See [the docs](https://github.com/pytorch/serve/blob/master/docs/rest_api.md) for information about the Inference (8080), Management (8081) and Metrics (8082) APIs.

## 4. Test deployment

```shell
-curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/3dogs.jpg
-curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T 3dogs.jpg
+curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T demo/demo.JPEG
```

You should obtain a response similar to:

```json
{
-"pred_label": 245,
-"pred_score": 0.5536593794822693,
-"pred_class": "French bulldog"
+"pred_label": 58,
+"pred_score": 0.38102269172668457,
+"pred_class": "water snake"
}
```

You can also use `test_torchserver.py` to compare the results of TorchServe and PyTorch, and visualize them.

```shell
python tools/deployment/test_torchserver.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${MODEL_NAME}
[--inference-addr ${INFERENCE_ADDR}] [--device ${DEVICE}]
```

Example:

```shell
python tools/deployment/test_torchserver.py \
demo/demo.JPEG \
configs/resnet/resnet18_b32x8_imagenet.py \
checkpoints/resnet18_8xb32_in1k_20210831-fbbb1da6.pth \
resnet18_in1k
```
16 changes: 14 additions & 2 deletions mmcls/apis/inference.py
@@ -89,7 +89,12 @@ def inference_model(model, img):
return result


-def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
+def show_result_pyplot(model,
+                       img,
+                       result,
+                       fig_size=(15, 10),
+                       title='result',
+                       wait_time=0):
"""Visualize the classification results on the image.
Args:
@@ -98,10 +103,17 @@ def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
result (list): The classification result.
fig_size (tuple): Figure size of the pyplot figure.
Defaults to (15, 10).
title (str): Title of the pyplot figure.
Defaults to 'result'.
wait_time (int): How many seconds to display the image.
Defaults to 0.
"""
if hasattr(model, 'module'):
model = model.module
model.show_result(
-        img, result, show=True, fig_size=fig_size, wait_time=wait_time)
+        img,
+        result,
+        show=True,
+        fig_size=fig_size,
+        win_name=title,
+        wait_time=wait_time)
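The `hasattr(model, 'module')` check at the top of `show_result_pyplot` unwraps models wrapped in `torch.nn.DataParallel`, which stores the real model in its `.module` attribute. A minimal, torch-free sketch of the pattern (the `Wrapped` class is a stand-in for illustration, not part of mmcls):

```python
class Wrapped:
    """Stand-in for torch.nn.DataParallel: keeps the real model in `.module`."""

    def __init__(self, module):
        self.module = module

def unwrap(model):
    # Same pattern used in show_result_pyplot before calling model.show_result.
    if hasattr(model, 'module'):
        model = model.module
    return model

inner = object()
assert unwrap(Wrapped(inner)) is inner  # wrapper is peeled off
assert unwrap(inner) is inner           # plain objects pass through unchanged
```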
1 change: 1 addition & 0 deletions requirements/optional.txt
@@ -1 +1,2 @@
albumentations>=0.3.2 --no-binary imgaug,albumentations
requests
2 changes: 1 addition & 1 deletion setup.cfg
@@ -14,7 +14,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
-known_third_party = PIL,m2r,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,recommonmark,rich,seaborn,sphinx,torch,torchvision,ts
+known_third_party = PIL,m2r,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,recommonmark,requests,rich,seaborn,sphinx,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

Expand Down
44 changes: 44 additions & 0 deletions tools/deployment/test_torchserver.py
@@ -0,0 +1,44 @@
from argparse import ArgumentParser

import numpy as np
import requests

from mmcls.apis import inference_model, init_model, show_result_pyplot


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
return args


def main(args):
# Inference single image by native apis.
model = init_model(args.config, args.checkpoint, device=args.device)
model_result = inference_model(model, args.img)
show_result_pyplot(model, args.img, model_result, title='pytorch_result')

# Inference single image by torchserve engine.
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
with open(args.img, 'rb') as image:
response = requests.post(url, image)
server_result = response.json()
show_result_pyplot(model, args.img, server_result, title='server_result')

assert np.allclose(model_result['pred_score'], server_result['pred_score'])
print('Test complete, the results of PyTorch and TorchServe are the same.')


if __name__ == '__main__':
args = parse_args()
main(args)
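The script's final `np.allclose` check tolerates the tiny floating-point drift introduced when the score is serialized to JSON and back (NumPy defaults: `rtol=1e-05`, `atol=1e-08`), so only a genuine mismatch between the PyTorch and TorchServe results fails the assertion. The second score below is an illustrative, slightly-truncated value, not real server output:

```python
import numpy as np

pytorch_score = 0.38102269172668457
server_score = 0.38102266  # hypothetical value after JSON round-tripping

# np.allclose(a, b) passes when |a - b| <= atol + rtol * |b|
# (defaults rtol=1e-05, atol=1e-08), so serialization round-off
# is tolerated while a real mismatch is not.
assert np.allclose(pytorch_score, server_score)
assert not np.allclose(pytorch_score, server_score + 1e-3)
```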
