Skip to content

Commit

Permalink
[Fix] Fix reppoints TensorRT support. (#1060)
Browse files Browse the repository at this point in the history
* Fix reppoints

* update todo

* typo fix
  • Loading branch information
q.yao authored Oct 27, 2022
1 parent 197a7ad commit 09add48
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,13 @@ def reppoints_head__get_bboxes(ctx,
scores = scores.sigmoid()
else:
scores = scores.softmax(-1)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)

# TODO: figure out why we can't reshape after permute directly
# TensorRT8.4 would fuse the permute+reshape,
# which leads to incorrect results.
bbox_pred = bbox_pred.permute(0, 2, 3, 1)
bbox_pred = bbox_pred.reshape(batch_size, -1)
bbox_pred = (bbox_pred + 0).reshape(batch_size, -1, 4)
if not is_dynamic_flag:
priors = priors.data
if pre_topk > 0:
Expand Down

0 comments on commit 09add48

Please sign in to comment.