fix batch input - Nvidia DALI #2455

Merged 5 commits on Jul 25, 2023
35 changes: 17 additions & 18 deletions examples/nvidia_dali/custom_handler.py
@@ -7,8 +7,8 @@
 import os

 import numpy as np
+import torch
 from nvidia.dali.pipeline import Pipeline
-from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

 from ts.torch_handler.image_classifier import ImageClassifier

@@ -31,14 +31,21 @@ def initialize(self, context):
         ]
         if not len(self.dali_file):
             raise RuntimeError("Missing dali pipeline file.")
-        self.PREFETCH_QUEUE_DEPTH = 2
         dali_config_file = os.path.join(self.model_dir, "dali_config.json")
         if not os.path.isfile(dali_config_file):
             raise RuntimeError("Missing dali_config.json file.")
         with open(dali_config_file) as setup_config_file:
             self.dali_configs = json.load(setup_config_file)
-        filename = os.path.join(self.model_dir, self.dali_file[0])
-        self.pipe = Pipeline.deserialize(filename=filename)
+        dali_filename = os.path.join(self.model_dir, self.dali_file[0])
+        self.pipe = Pipeline.deserialize(
+            filename=dali_filename,
+            batch_size=self.dali_configs["batch_size"],
+            num_threads=self.dali_configs["num_threads"],
+            prefetch_queue_depth=1,
+            device_id=self.dali_configs["device_id"],
+            seed=self.dali_configs["seed"],
+        )
         self.pipe.build()
+        # pylint: disable=protected-access
         self.pipe._max_batch_size = self.dali_configs["batch_size"]
         self.pipe._num_threads = self.dali_configs["num_threads"]
@@ -54,24 +61,16 @@ def preprocess(self, data):
             list : The preprocess function returns the input image as a list of float tensors.
         """
         batch_tensor = []
+        result = []

         input_byte_arrays = [i["body"] if "body" in i else i["data"] for i in data]
         for byte_array in input_byte_arrays:
             np_image = np.frombuffer(byte_array, dtype=np.uint8)
             batch_tensor.append(np_image)  # we can use numpy

-        for _ in range(self.PREFETCH_QUEUE_DEPTH):
-            self.pipe.feed_input("my_source", batch_tensor)
-
-        datam = DALIGenericIterator(
-            [self.pipe],
-            ["data"],
-            last_batch_policy=LastBatchPolicy.PARTIAL,
-            last_batch_padded=True,
-        )
-        result = []
-        for _, data in enumerate(datam):
-            result.append(data[0]["data"])
-            break
+        response = self.pipe.run(source=batch_tensor)
+        for idx, _ in enumerate(response[0]):
+            data = torch.tensor(response[0].at(idx))
+            result.append(data.unsqueeze(0))

-        return result[0].to(self.device)
+        return torch.cat(result).to(self.device)
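
For context on the new preprocess flow, the following is a minimal, standalone sketch (not part of the PR) of the same pattern: an external_source named "source" declared with batch=False, a pipeline deserialized with explicit batch_size, num_threads, prefetch_queue_depth and seed, and a request batch fed directly through Pipeline.run(). It is deliberately CPU-only (a toy cast/normalize op instead of the mixed decoder, no device_id) so it can run without a GPU; toy_pipe and all variable names are placeholders, and the behavior assumes the DALI 1.27 API pinned in requirements.txt.

import numpy as np
import torch
import nvidia.dali as dali
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline


@dali.pipeline_def
def toy_pipe():
    # batch=False: each element fed at run() time is one sample, not a whole batch
    data = dali.fn.external_source(dtype=types.UINT8, name="source", batch=False)
    return dali.fn.cast(data, dtype=types.FLOAT) / 255.0


# Serialize once (the role of serialize_dali_pipeline.py) ...
serialized = toy_pipe(batch_size=2, num_threads=1, device_id=None, seed=12).serialize()

# ... then deserialize with explicit parameters (the role of initialize() above).
pipe = Pipeline.deserialize(
    serialized_pipeline=serialized,
    batch_size=2,
    num_threads=1,
    prefetch_queue_depth=1,
    seed=12,
)
pipe.build()

# Feed the per-request byte buffers straight into the external source and collect
# the outputs as one batched torch tensor, mirroring the new preprocess().
batch = [np.frombuffer(b"\x01\x02\x03", dtype=np.uint8) for _ in range(2)]
out = pipe.run(source=batch)  # keyword must match the external_source name
result = [torch.tensor(out[0].at(i)).unsqueeze(0) for i in range(len(out[0]))]
print(torch.cat(result).shape)  # -> torch.Size([2, 3])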
3 changes: 2 additions & 1 deletion examples/nvidia_dali/dali_config.json
@@ -1,5 +1,6 @@
 {
     "batch_size" : 5,
     "num_threads" : 2,
-    "device_id" : 0
+    "device_id" : 0,
+    "seed": 12
 }
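
Both serialize_dali_pipeline.py and the handler's initialize() now read the same four keys, so a config missing the new "seed" entry would only surface as a KeyError at model load. Purely as an illustration (this helper is not in the PR, and load_dali_config is a made-up name), a small validation sketch:

import json

# Keys that serialize_dali_pipeline.py and custom_handler.py both expect to find.
REQUIRED_KEYS = ("batch_size", "num_threads", "device_id", "seed")


def load_dali_config(path="dali_config.json"):
    """Hypothetical helper: load the config and fail early if a key is missing."""
    with open(path) as f:
        config = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise RuntimeError(f"dali_config.json is missing keys: {missing}")
    return config


if __name__ == "__main__":
    print(load_dali_config())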
2 changes: 1 addition & 1 deletion examples/nvidia_dali/requirements.txt
@@ -1,2 +1,2 @@
-nvidia-dali-cuda110==1.18.0
+nvidia-dali-cuda110==1.27.0
 --extra-index-url https://developer.download.nvidia.com/compute/redist
9 changes: 6 additions & 3 deletions examples/nvidia_dali/serialize_dali_pipeline.py
@@ -16,7 +16,7 @@ def parse_args():

 @dali.pipeline_def
 def pipe():
-    jpegs = dali.fn.external_source(dtype=types.UINT8, name="my_source")
+    jpegs = dali.fn.external_source(dtype=types.UINT8, name="source", batch=False)
     decoded = dali.fn.decoders.image(jpegs, device="mixed")
     resized = dali.fn.resize(
         decoded,
@@ -43,9 +43,12 @@ def main(filename):
     batch_size = config["batch_size"]
     num_threads = config["num_threads"]
     device_id = config["device_id"]
+    seed = config["seed"]

-    pipe1 = pipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id)
-    pipe1.serialize(filename=filename)
+    pipeline = pipe(
+        batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=seed
+    )
+    pipeline.serialize(filename=filename)
     print("Saved {}".format(filename))
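
To close the loop, here is a hedged round-trip sketch (assumptions, not part of the PR): serialize a pipeline like the example's to a .dali file using the values from dali_config.json, then deserialize it the way the handler does and confirm the overridden batch size took effect. The file name model.dali and the simplified fixed 224x224 resize (standing in for the example's full preprocessing) are choices made for this sketch, and the mixed decoder means it needs a GPU, like the example itself.

import json

import nvidia.dali as dali
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline


@dali.pipeline_def
def pipe():
    # Same contract as the example: a per-sample external source named "source".
    jpegs = dali.fn.external_source(dtype=types.UINT8, name="source", batch=False)
    decoded = dali.fn.decoders.image(jpegs, device="mixed")
    return dali.fn.resize(decoded, resize_x=224, resize_y=224)


with open("dali_config.json") as f:
    config = json.load(f)

# Serialize with the configured batch size, thread count, device and seed ...
pipe(
    batch_size=config["batch_size"],
    num_threads=config["num_threads"],
    device_id=config["device_id"],
    seed=config["seed"],
).serialize(filename="model.dali")

# ... and deserialize exactly as custom_handler.initialize() does.
restored = Pipeline.deserialize(
    filename="model.dali",
    batch_size=config["batch_size"],
    num_threads=config["num_threads"],
    prefetch_queue_depth=1,
    device_id=config["device_id"],
    seed=config["seed"],
)
restored.build()
print("max_batch_size:", restored.max_batch_size)  # should equal config["batch_size"]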