diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 9831cfdcb45..7baece2ae2c 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -1,6 +1,7 @@ import functools import torch +import torch._custom_ops import torch.library # Ensure that torch.ops.torchvision is visible @@ -48,3 +49,17 @@ def meta_roi_align_backward( ), ) return grad.new_empty((batch_size, channels, height, width)) + + +@torch._custom_ops.impl_abstract("torchvision::nms") +def meta_nms(dets, scores, iou_threshold): + torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") + torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") + torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") + torch._check( + dets.size(0) == scores.size(0), + lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", + ) + ctx = torch._custom_ops.get_ctx() + num_to_keep = ctx.create_unbacked_symint() + return dets.new_empty(num_to_keep, dtype=torch.long)