From bf9ada6fc9a53d97f22213b07fb7703fa8dfa68e Mon Sep 17 00:00:00 2001 From: Alankar Mahajan Date: Fri, 1 Mar 2024 18:37:53 +0530 Subject: [PATCH] Set device of NMS op output based on input Signed-off-by: Alankar Mahajan --- .../torch/src/python/aimet_torch/elementwise_ops.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index 4445ab02244..4337bf85fbb 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -388,9 +388,10 @@ def forward(self, *args) -> torch.Tensor: res_per_class = res_per_class[:self.max_output_boxes_per_class] res.extend(res_per_class) - res = torch.Tensor(res).type(torch.int64) - out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64) - indices = torch.arange(0, len(res) * 3, dtype=torch.int64) + res = torch.tensor(res, dtype=torch.int64, device=args[0].device) + out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, + dtype=torch.int64, device=args[0].device) + indices = torch.arange(0, len(res) * 3, dtype=torch.int64, device=args[0].device) out.put_(indices, res) return out