You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.
Checklist
Describe the bug
yolox转onnx后,动态batch size时,给定输入大于1张图像时会报错。我查看了一下代码发现如下
第97行
max_scores, _ = torch.max(scores, 1)
这里我猜测是需要对任意image(即batch 维度),任意位置(即anchor维度),取所有类别sigmoid后分数最大值(再进一步判断当前图像当前anchor是否有必要输出,即是否>conf threshold),但scores是尺寸为(N,Anchors,K)的tensor,max的dim恐怕是-1(即取categories维度上的max)
即修改为max_scores, _ = torch.max(scores, -1)
然后第99行
scores = scores.where(mask, scores.new_zeros(1))
mask可能需要unsqueeze and repeat一下。
即修改为scores = scores.where(torch.unsqueeze(mask, dim=-1).repeat(1,1,K), scores.new_zeros(1))
K=number of categories (不过可能不用repeat,广播应该是OK的)
另外我对照了mmdet里边对应功能的代码 v3.0.0rc6
mmdet/models/dense_heads/yolox_head.py中的predict_by_feat方法第307行,这里应该是想做同样的事情
max_scores, labels = torch.max(flatten_cls_scores[img_id], 1)
但这里max的dim取1是因为前边flatten_cls_scores[img_id]已经取了img_id,相当于N维度已经不在了,所以dim=1对应categories维度。
Reproduction
详见上节
Environment
Error traceback
No response
The text was updated successfully, but these errors were encountered: