Skip to content

Commit

Permalink
Set device of NMS op output based on input
Browse files Browse the repository at this point in the history
Signed-off-by: Alankar Mahajan <quic_alanmaha@quicinc.com>
  • Loading branch information
quic-alanmaha authored Mar 5, 2024
1 parent 93ec6bc commit d53ca19
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,10 @@ def forward(self, *args) -> torch.Tensor:
res_per_class = res_per_class[:self.max_output_boxes_per_class]
res.extend(res_per_class)

res = torch.Tensor(res).type(torch.int64)
out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64)
indices = torch.arange(0, len(res) * 3, dtype=torch.int64)
res = torch.tensor(res, dtype=torch.int64, device=args[0].device)
out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3,
dtype=torch.int64, device=args[0].device)
indices = torch.arange(0, len(res) * 3, dtype=torch.int64, device=args[0].device)
out.put_(indices, res)
return out

Expand Down

0 comments on commit d53ca19

Please sign in to comment.