diff --git a/examples/nvidia_dali/custom_handler.py b/examples/nvidia_dali/custom_handler.py index bf3337abb0..97c091ca2f 100644 --- a/examples/nvidia_dali/custom_handler.py +++ b/examples/nvidia_dali/custom_handler.py @@ -7,8 +7,8 @@ import os import numpy as np +import torch from nvidia.dali.pipeline import Pipeline -from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy from ts.torch_handler.image_classifier import ImageClassifier @@ -31,14 +31,21 @@ def initialize(self, context): ] if not len(self.dali_file): raise RuntimeError("Missing dali pipeline file.") - self.PREFETCH_QUEUE_DEPTH = 2 dali_config_file = os.path.join(self.model_dir, "dali_config.json") if not os.path.isfile(dali_config_file): raise RuntimeError("Missing dali_config.json file.") with open(dali_config_file) as setup_config_file: self.dali_configs = json.load(setup_config_file) - filename = os.path.join(self.model_dir, self.dali_file[0]) - self.pipe = Pipeline.deserialize(filename=filename) + dali_filename = os.path.join(self.model_dir, self.dali_file[0]) + self.pipe = Pipeline.deserialize( + filename=dali_filename, + batch_size=self.dali_configs["batch_size"], + num_threads=self.dali_configs["num_threads"], + prefetch_queue_depth=1, + device_id=self.dali_configs["device_id"], + seed=self.dali_configs["seed"], + ) + self.pipe.build() # pylint: disable=protected-access self.pipe._max_batch_size = self.dali_configs["batch_size"] self.pipe._num_threads = self.dali_configs["num_threads"] @@ -54,24 +61,16 @@ def preprocess(self, data): list : The preprocess function returns the input image as a list of float tensors. """ batch_tensor = [] + result = [] input_byte_arrays = [i["body"] if "body" in i else i["data"] for i in data] for byte_array in input_byte_arrays: np_image = np.frombuffer(byte_array, dtype=np.uint8) batch_tensor.append(np_image) # we can use numpy - for _ in range(self.PREFETCH_QUEUE_DEPTH): - self.pipe.feed_input("my_source", batch_tensor) - - datam = DALIGenericIterator( - [self.pipe], - ["data"], - last_batch_policy=LastBatchPolicy.PARTIAL, - last_batch_padded=True, - ) - result = [] - for _, data in enumerate(datam): - result.append(data[0]["data"]) - break + response = self.pipe.run(source=batch_tensor) + for idx, _ in enumerate(response[0]): + data = torch.tensor(response[0].at(idx)) + result.append(data.unsqueeze(0)) - return result[0].to(self.device) + return torch.cat(result).to(self.device) diff --git a/examples/nvidia_dali/dali_config.json b/examples/nvidia_dali/dali_config.json index 63611f804f..1102bf7ebc 100644 --- a/examples/nvidia_dali/dali_config.json +++ b/examples/nvidia_dali/dali_config.json @@ -1,5 +1,6 @@ { "batch_size" : 5, "num_threads" : 2, - "device_id" : 0 + "device_id" : 0, + "seed": 12 } diff --git a/examples/nvidia_dali/requirements.txt b/examples/nvidia_dali/requirements.txt index dd79b8db37..169ee9f045 100644 --- a/examples/nvidia_dali/requirements.txt +++ b/examples/nvidia_dali/requirements.txt @@ -1,2 +1,2 @@ -nvidia-dali-cuda110==1.18.0 +nvidia-dali-cuda110==1.27.0 --extra-index-url https://developer.download.nvidia.com/compute/redist diff --git a/examples/nvidia_dali/serialize_dali_pipeline.py b/examples/nvidia_dali/serialize_dali_pipeline.py index 3f01d53723..989b12b0cb 100644 --- a/examples/nvidia_dali/serialize_dali_pipeline.py +++ b/examples/nvidia_dali/serialize_dali_pipeline.py @@ -16,7 +16,7 @@ def parse_args(): @dali.pipeline_def def pipe(): - jpegs = dali.fn.external_source(dtype=types.UINT8, name="my_source") + jpegs = dali.fn.external_source(dtype=types.UINT8, name="source", batch=False) decoded = dali.fn.decoders.image(jpegs, device="mixed") resized = dali.fn.resize( decoded, @@ -43,9 +43,12 @@ def main(filename): batch_size = config["batch_size"] num_threads = config["num_threads"] device_id = config["device_id"] + seed = config["seed"] - pipe1 = pipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id) - pipe1.serialize(filename=filename) + pipeline = pipe( + batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=seed + ) + pipeline.serialize(filename=filename) print("Saved {}".format(filename))