move dali preprocess to handler util #2485

Merged
merged 24 commits into from
Sep 9, 2023
77 changes: 51 additions & 26 deletions examples/nvidia_dali/README.md
Expand Up @@ -2,56 +2,91 @@

The NVIDIA Data Loading Library (DALI) is a library for data loading and pre-processing to accelerate deep learning applications. It provides a collection of highly optimized building blocks for loading and processing image, video and audio data.

In this example, we use NVIDIA DALI for pre-processing image input for inference in resnet-18 model.
In this example, we use NVIDIA DALI for pre-processing image input for inference in resnet-18 and mnist models.

Refer to [NVIDIA-DALI-Documentation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html) for detailed information.

### Install dependencies
## Install dependencies

Navigate to `serve/examples/nvidia_dali` directory and run the below command to install the dependencies

```bash
pip install -r requirements.txt
```

## DALI Pre-Processing with Default dali_image_classifier handler for resnet-18 model

### Create model-archive file

Navigate to `serve` folder

Download model weights

```bash
wget https://download.pytorch.org/models/resnet18-f37072fd.pth
```

```bash
torch-model-archiver --model-name resnet-18 --version 1.0 --model-file ./examples/image_classifier/resnet_18/model.py --serialized-file resnet18-f37072fd.pth --handler dali_image_classifier --config-file ./examples/nvidia_dali/model-config.yaml --extra-files ./examples/image_classifier/index_to_name.json
```

Create a new directory `model_store` and move the model-archive file

```bash
mkdir model_store
mv resnet-18.mar model_store/
```

### Start TorchServe

```bash
torchserve --start --model-store model_store --models resnet=resnet-18.mar
```

### Run Inference

Get the inference for a sample image using the below command

```bash
curl http://127.0.0.1:8080/predictions/resnet -T ./examples/image_classifier/kitten.jpg
```

## DALI Pre-Processing in a Custom Handler for mnist model

### Define and Build DALI Pipeline

In DALI, any data processing task has a central object called Pipeline.
Refer to [NVIDIA-DALI](https://github.com/NVIDIA/DALI) for more details on DALI pipeline.

Navigate to `./serve/examples/nvidia_dali`.

Change the `dali_config.json` variables
Change the `model-config.yaml` variables

`batch_size` - Maximum batch size of pipeline.

`num_threads` - Number of CPU threads used by the pipeline.

`device_id` - ID of GPU device used by pipeline.

`pipeline_file` - Pipeline filename

Run the Python script, which serializes the DALI pipeline and saves it to `model.dali`

```bash
python serialize_dali_pipeline.py --config dali_config.json
python serialize_dali_pipeline.py --config model-config.yaml
```

**__Note__**:

- Make sure that the serialized file has the extension `.dali`
- Make sure that the serialized file named `model.dali` is created.
- The TorchServe batch size should match the DALI batch size.

### Download the resnet .pth file

```bash
wget https://download.pytorch.org/models/resnet18-f37072fd.pth
```

### Create model-archive file

The following command creates a `.mar` archive that also includes the `model.dali` and `dali_config.json` files.
The following command creates a `.mar` archive that also includes the `model.dali` file.

```bash
torch-model-archiver --model-name resnet-18 --version 1.0 --model-file ../image_classifier/resnet_18/model.py --serialized-file resnet18-f37072fd.pth --handler custom_handler.py --extra-files ../image_classifier/index_to_name.json,./model.dali,./dali_config.json
torch-model-archiver --model-name mnist --version 1.0 --model-file ../image_classifier/mnist/mnist.py --serialized-file ../image_classifier/mnist/mnist_cnn.pt --handler custom_handler.py --extra-files ./model.dali --config-file model-config.yaml
```

Navigate to `serve` directory and run the below commands
Expand All @@ -60,29 +95,19 @@ Create a new directory `model_store` and move the model-archive file

```bash
mkdir model_store
mv resnet-18.mar model_store/
mv mnist.mar model_store/
```

### Start TorchServe

```bash
torchserve --start --model-store model_store --models resnet-18=resnet-18.mar
torchserve --start --model-store model_store --models mnist=mnist.mar
```

### Run Inference

Get the inference for a sample image using the below command

```bash
curl http://127.0.0.1:8080/predictions/resnet-18 -T ./examples/image_classifier/kitten.jpg
```

```json
{
"tabby": 0.408751517534256,
"tiger_cat": 0.35404905676841736,
"Egyptian_cat": 0.12418942898511887,
"lynx": 0.025347290560603142,
"bucket": 0.011393273249268532
}
curl http://127.0.0.1:8080/predictions/mnist -T ./examples/image_classifier/mnist/test_data/0.png
```
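
The same request can also be issued from Python. The sketch below mirrors `curl -T`, which uploads the file body with an HTTP PUT. The helper name `build_prediction_request` is hypothetical, and the host and port are simply copied from the curl command. It only builds the request; sending it (shown commented out) assumes a running TorchServe instance.

```python
import urllib.request


def build_prediction_request(model_name, payload, base_url="http://127.0.0.1:8080"):
    """Build (but do not send) a PUT request to the predictions endpoint."""
    url = "{}/predictions/{}".format(base_url, model_name)
    return urllib.request.Request(url, data=payload, method="PUT")


# Placeholder bytes; in practice, read the image file in binary mode.
request = build_prediction_request("mnist", b"raw image bytes go here")
print(request.get_method(), request.full_url)

# To send it against a running server:
# with urllib.request.urlopen(request) as response:
#     print(response.read().decode())
```
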
77 changes: 23 additions & 54 deletions examples/nvidia_dali/custom_handler.py
Expand Up @@ -3,74 +3,43 @@
"""
Base module for all vision handlers
"""
import json
import os
from torch.profiler import ProfilerActivity

import numpy as np
import torch
from nvidia.dali.pipeline import Pipeline
from ts.torch_handler.dali_handler import DALIHandler

from ts.torch_handler.image_classifier import ImageClassifier


class DALIHandler(ImageClassifier):
class DALIMNISTDigitClassifier(DALIHandler):
"""
Base class for all vision handlers
"""

topk = 5
# These are the standard Imagenet dimensions
# and statistics

def __init__(self):
super(DALIHandler, self).__init__()
super(DALIMNISTDigitClassifier, self).__init__()

def initialize(self, context):
super().initialize(context)
properties = context.system_properties
self.model_dir = properties.get("model_dir")
self.profiler_args = {
"activities": [ProfilerActivity.CPU],
"record_shapes": True,
}

self.dali_file = [
file for file in os.listdir(self.model_dir) if file.endswith(".dali")
]
if not len(self.dali_file):
raise RuntimeError("Missing dali pipeline file.")
dali_config_file = os.path.join(self.model_dir, "dali_config.json")
if not os.path.isfile(dali_config_file):
raise RuntimeError("Missing dali_config.json file.")
with open(dali_config_file) as setup_config_file:
self.dali_configs = json.load(setup_config_file)
dali_filename = os.path.join(self.model_dir, self.dali_file[0])
self.pipe = Pipeline.deserialize(
filename=dali_filename,
batch_size=self.dali_configs["batch_size"],
num_threads=self.dali_configs["num_threads"],
prefetch_queue_depth=1,
device_id=self.dali_configs["device_id"],
seed=self.dali_configs["seed"],
)
self.pipe.build()
# pylint: disable=protected-access
self.pipe._max_batch_size = self.dali_configs["batch_size"]
self.pipe._num_threads = self.dali_configs["num_threads"]
self.pipe._device_id = self.dali_configs["device_id"]
def set_max_result_classes(self, topk):
self.topk = topk

def preprocess(self, data):
"""The preprocess function of MNIST program converts the input data to a float tensor
def get_max_result_classes(self):
return self.topk

Args:
data (List): Input data from the request is in the form of a Tensor
def postprocess(self, data):
"""The post process of MNIST converts the predicted output response to a label.

Args:
data (list): The predicted output from the Inference with probabilities is passed
to the post-process function
Returns:
list : The preprocess function returns the input image as a list of float tensors.
list : A list of dictionaries with predictions and explanations is returned
"""
batch_tensor = []
result = []

input_byte_arrays = [i["body"] if "body" in i else i["data"] for i in data]
for byte_array in input_byte_arrays:
np_image = np.frombuffer(byte_array, dtype=np.uint8)
batch_tensor.append(np_image) # we can use numpy

response = self.pipe.run(source=batch_tensor)
for idx, _ in enumerate(response[0]):
data = torch.tensor(response[0].at(idx))
result.append(data.unsqueeze(0))

return torch.cat(result).to(self.device)
return data.argmax(1).tolist()
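
For readers without PyTorch at hand, the following plain-Python sketch shows what `data.argmax(1).tolist()` computes in `postprocess`: for each row of scores in the batch, the index of the largest score, which for MNIST is the predicted digit. The function name `argmax_per_row` and the sample scores are illustrative only and are not part of the handler.

```python
def argmax_per_row(scores):
    """Return, for each row, the index of its largest value."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Two requests in one batch, three classes each (made-up scores).
batch_scores = [
    [0.1, 2.5, -0.3],
    [1.7, 0.2, 0.4],
]
print(argmax_per_row(batch_scores))  # [1, 0]
```
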
6 changes: 0 additions & 6 deletions examples/nvidia_dali/dali_config.json

This file was deleted.

1 change: 0 additions & 1 deletion examples/nvidia_dali/index_to_name.json

This file was deleted.

15 changes: 15 additions & 0 deletions examples/nvidia_dali/model-config.yaml
@@ -0,0 +1,15 @@
#frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
deviceType: "gpu"

#backend settings
dali:
batch_size: 5
num_threads: 2
device_id: 0
seed: 12
# For using a custom DALI pipeline, uncomment the below line
# pipeline_file: "model.dali"
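
As an illustration of how the `dali` section might be consumed, the sketch below validates an already-parsed configuration dictionary (for example, the result of `yaml.safe_load`). The helper `read_dali_settings` is hypothetical, and treating `pipeline_file` as optional is an assumption based on the commented-out line in this file, not confirmed handler behavior.

```python
REQUIRED_KEYS = ("batch_size", "num_threads", "device_id", "seed")


def read_dali_settings(config):
    """Return the `dali` section after checking that required keys exist."""
    dali = config.get("dali")
    if not isinstance(dali, dict):
        raise ValueError("model config has no `dali` section")
    missing = [key for key in REQUIRED_KEYS if key not in dali]
    if missing:
        raise ValueError("missing dali settings: {}".format(", ".join(missing)))
    settings = {key: dali[key] for key in REQUIRED_KEYS}
    # Assumption: `pipeline_file` is optional; None means "no custom pipeline".
    settings["pipeline_file"] = dali.get("pipeline_file")
    return settings


# Dictionary equivalent of the YAML file, with `pipeline_file` left commented out.
config = {"dali": {"batch_size": 5, "num_threads": 2, "device_id": 0, "seed": 12}}
print(read_dali_settings(config))
```
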
6 changes: 0 additions & 6 deletions examples/nvidia_dali/model.py

This file was deleted.

2 changes: 1 addition & 1 deletion examples/nvidia_dali/requirements.txt
@@ -1,2 +1,2 @@
nvidia-dali-cuda110==1.27.0
nvidia-dali-cuda110>=1.27.0
--extra-index-url https://developer.download.nvidia.com/compute/redist
52 changes: 20 additions & 32 deletions examples/nvidia_dali/serialize_dali_pipeline.py
@@ -1,58 +1,46 @@
import json
import os

import nvidia.dali as dali
import nvidia.dali.types as types
import yaml


def parse_args():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save", default="./model.dali")
parser.add_argument("--config", default="dali_config.json")
parser.add_argument("--config", default="model-config.yaml")
return parser.parse_args()


@dali.pipeline_def
# https://docs.nvidia.com/deeplearning/dali/user-guide/docs/operations/nvidia.dali.fn.html
def pipe():
jpegs = dali.fn.external_source(dtype=types.UINT8, name="source", batch=False)
decoded = dali.fn.decoders.image(jpegs, device="mixed")
resized = dali.fn.resize(
decoded,
size=[256],
subpixel_scale=False,
interp_type=types.DALIInterpType.INTERP_LINEAR,
antialias=True,
mode="not_smaller",
)
jpegs = dali.fn.external_source(dtype=types.UINT8, name="source")
decoded = dali.fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
normalized = dali.fn.crop_mirror_normalize(
resized,
crop_pos_x=0.5,
crop_pos_y=0.5,
crop=(224, 224),
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
decoded,
mean=[0.1307 * 255],
std=[0.3081 * 255],
)
return normalized


def main(filename):
with open(args.config) as fp:
config = json.load(fp)
batch_size = config["batch_size"]
num_threads = config["num_threads"]
device_id = config["device_id"]
seed = config["seed"]
def main():
config = {}
with open(args.config, "r") as file:
config = yaml.safe_load(file)
batch_size = config["dali"]["batch_size"]
num_threads = config["dali"]["num_threads"]
device_id = config["dali"]["device_id"]
seed = config["dali"]["seed"]
pipeline_filename = config["dali"]["pipeline_file"]

pipeline = pipe(
batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=seed
)
pipeline.serialize(filename=filename)
print("Saved {}".format(filename))
pipeline.serialize(filename=pipeline_filename)
print("Saved {}".format(pipeline_filename))


if __name__ == "__main__":
args = parse_args()
os.makedirs(os.path.dirname(args.save), exist_ok=True)
main(args.save)
main()
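
Note that `main()` reads `config["dali"]["pipeline_file"]`, while that key is commented out in the `model-config.yaml` added by this change, and the `--save` argument is no longer passed to `main`. One possible way to reconcile the two, sketched under the assumption that `--save` should act as the fallback, is shown below. `resolve_pipeline_filename` is a hypothetical helper, not part of the PR.

```python
import os


def resolve_pipeline_filename(config, save_path):
    """Prefer `dali.pipeline_file` from the config, else fall back to save_path."""
    filename = config.get("dali", {}).get("pipeline_file") or save_path
    directory = os.path.dirname(filename)
    return filename, directory


# With `pipeline_file` commented out, the --save default is used.
print(resolve_pipeline_filename({"dali": {"seed": 12}}, "./model.dali"))
# With `pipeline_file` set, the config value wins.
print(resolve_pipeline_filename({"dali": {"pipeline_file": "custom.dali"}}, "./model.dali"))
```

The directory component is returned separately because it is empty for a bare filename such as `custom.dali`, and `os.makedirs("")` raises an error, so a caller would only create the directory when that component is non-empty.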
4 changes: 1 addition & 3 deletions kubernetes/kserve/kf_request_json/v2/mnist/README.md
Expand Up @@ -61,7 +61,6 @@ For tensor input, use [totensor](totensor.py) utility
python totensor.py 0.png
```


## Deploying the model in local machine

Start TorchServe
Expand Down Expand Up @@ -94,7 +93,6 @@ Expected Output
{"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298", "model_name": "mnist", "model_version": "1.0", "outputs": [{"name": "predict", "shape": [1], "datatype": "INT64", "data": [0]}]}
```


## Sample request and response for tensor input


Expand All @@ -105,11 +103,11 @@ curl -v -H "Content-Type: application/json" http://localhost:8080/v2/models/mnis
```

Expected output

```bash
{"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298", "model_name": "mnist", "model_version": "1.0", "outputs": [{"name": "predict", "shape": [1], "datatype": "INT64", "data": [0]}]}
```


## Sample request and response for captum

Run the following command
Expand Down
1 change: 1 addition & 0 deletions model-archiver/model_archiver/model_packaging_utils.py
Expand Up @@ -31,6 +31,7 @@
"image_classifier": "vision",
"object_detector": "vision",
"image_segmenter": "vision",
"dali_image_classifier": "vision",
}

MODEL_SERVER_VERSION = "1.0"