Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] Add NMS Kernel support with Triton Implementation #8746

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
XPU_NOT_AVAILABLE_MSG = "XPU device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."


Expand Down Expand Up @@ -141,6 +142,12 @@ def needs_mps(test_func):
return pytest.mark.needs_mps(test_func)


def needs_xpu(test_func):
    """Apply the ``needs_xpu`` pytest marker to ``test_func`` and return the result."""
    # pytest is imported when the helper is called, not at module import time.
    import pytest  # noqa

    xpu_marker = pytest.mark.needs_xpu
    return xpu_marker(test_func)


def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
Expand Down
8 changes: 8 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
IN_RE_WORKER,
MPS_NOT_AVAILABLE_MSG,
OSS_CI_GPU_NO_CUDA_MSG,
XPU_NOT_AVAILABLE_MSG,
)


def pytest_configure(config):
    """Register the custom markers of this test-suite on the pytest ``config``."""
    # Additional markers (see pytest_collection_modifyitems), registered in this order.
    marker_lines = (
        "needs_cuda: mark for tests that rely on a CUDA device",
        "needs_mps: mark for tests that rely on a MPS device",
        "needs_xpu: mark for tests that rely on a XPU device",
        "dont_collect: mark for tests that should not be collected",
        "opcheck_only_one: only opcheck one parametrization",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)

Expand All @@ -43,12 +45,18 @@ def pytest_collection_modifyitems(items):
# and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None
needs_mps = item.get_closest_marker("needs_mps") is not None
needs_xpu = item.get_closest_marker("needs_xpu") is not None

if needs_cuda and not torch.cuda.is_available():
# In general, we skip cuda tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

if needs_xpu and not torch.xpu.is_available():
# In general, we skip xpu tests on machines without an XPU device
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=XPU_NOT_AVAILABLE_MSG))

if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

Expand Down
1 change: 1 addition & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def test_qnms(self, iou, scale, zero_point):
(
pytest.param("cuda", marks=pytest.mark.needs_cuda),
pytest.param("mps", marks=pytest.mark.needs_mps),
pytest.param("xpu", marks=pytest.mark.needs_xpu),
),
)
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
Expand Down
50 changes: 50 additions & 0 deletions torchvision/csrc/ops/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,60 @@ at::Tensor nms_kernel(
return result;
}


/**
 * @brief Serial post-processing step of the Non-Maximum Suppression (NMS)
 * algorithm.
 *
 * Walks the boxes in order and greedily keeps every box that has not been
 * suppressed by a previously kept box. Suppression state is tracked in a
 * bitmask that uses the low 32 bits of each mask word.
 *
 * @param order A tensor containing the order of the boxes. The result is
 * `order` indexed by the positions that were kept.
 * @param iou_keep_out_mask The IoU keep-out mask, read as a dense row-major
 * (N, ceil(N / 32)) matrix of int64 words, where N is the number of boxes.
 * The datatype MUST be int64 (it is accessed through data_ptr<int64_t>()).
 * Only the low 32 bits of each word are inspected: bit (31 - j % 32) of word
 * [i, j / 32] being set means that keeping box i suppresses box j.
 * @param num_boxes The total number of boxes.
 * @return A tensor containing the entries of `order` for the boxes to keep.
 */
at::Tensor nms_kernel_postprocess(
    const at::Tensor& order,
    const at::Tensor& iou_keep_out_mask,
    const int64_t num_boxes) {
  TORCH_CHECK(
      num_boxes >= 0, "num_boxes must be non-negative, got ", num_boxes);

  // Number of 32-bit blocks needed to cover all boxes. Kept as int64_t so the
  // row offset computation below cannot be narrowed.
  const int64_t col_blocks = (num_boxes + 32 - 1) / 32;
  TORCH_CHECK(
      iou_keep_out_mask.numel() >= num_boxes * col_blocks,
      "iou_keep_out_mask must hold at least num_boxes * ceil(num_boxes / 32) = ",
      num_boxes * col_blocks,
      " elements, got ",
      iou_keep_out_mask.numel());

  // The rows are addressed with raw pointer arithmetic, which requires a dense
  // layout. contiguous() is a no-op when the tensor already is dense.
  const at::Tensor mask = iou_keep_out_mask.contiguous();

  // Value-initialised to zero by the constructor. This is also well-defined
  // when col_blocks == 0 (no boxes), unlike taking &remove_box[0].
  std::vector<unsigned long long> remove_box(col_blocks, 0ULL);

  at::Tensor keep = at::empty(
      {num_boxes}, order.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_data_ptr = keep.data_ptr<int64_t>();
  const int64_t* mask_data_ptr = mask.data_ptr<int64_t>();

  int64_t num_to_keep = 0;
  // Iterate over each box in order and check whether it should be kept.
  for (int64_t i = 0; i < num_boxes; i++) {
    const int64_t nblock = i / 32;
    // Boxes are stored most-significant-bit first inside each 32-bit block.
    const int inblock = 31 - static_cast<int>(i % 32);

    if (!(remove_box[nblock] & (1ULL << inblock))) {
      keep_data_ptr[num_to_keep++] = i;
      // Box i is kept: merge its suppression row into the running mask. Blocks
      // before nblock only describe boxes that were already visited.
      const int64_t* row = mask_data_ptr + i * col_blocks;
      for (int64_t j = nblock; j < col_blocks; j++) {
        remove_box[j] |= static_cast<unsigned long long>(row[j]);
      }
    }
  }
  return order.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)});
}



} // namespace

// Bind nms_kernel and nms_kernel_postprocess as the CPU implementations of the
// torchvision::nms and torchvision::nms_kernel_postprocess operators.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms_kernel_postprocess"), TORCH_FN(nms_kernel_postprocess));
}

} // namespace ops
Expand Down
2 changes: 2 additions & 0 deletions torchvision/csrc/ops/nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.set_python_module("torchvision._meta_registrations");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor"));
}

} // namespace ops
Expand Down
15 changes: 15 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
return keep


def _nms_kernel_postprocess(order: Tensor, iou_keep_out_mask: Tensor, num_boxes: int) -> Tensor:
    """
    Post-processes the results of the non-maximum suppression (NMS) kernel.

    Thin wrapper that forwards its arguments unchanged to the
    ``torchvision::nms_kernel_postprocess`` custom operator.

    Args:
        order (Tensor): A tensor containing the order of the boxes.
        iou_keep_out_mask (Tensor): The IoU keep-out bitmask, i.e. which boxes get
            suppressed when a given box is kept. It is stored as int64 words of which
            32 bits carry mask information.
        num_boxes (int): The number of boxes.
    Returns:
        Tensor: A tensor containing the post-processed results of the NMS kernel.
    """

    return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes)


def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
"""
Clip boxes so that they lie inside an image of size ``size``.
Expand Down
Empty file.
98 changes: 98 additions & 0 deletions torchvision/ops/triton/nms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import triton
import triton.language as tl


@triton.jit
def _combine_bits(val0, val1):
    # Combine function handed to ``tl.reduce``: merges two partial bitmasks with a bitwise OR.
    # Both operands are statically required to be int32.
    tl.static_assert(val0.dtype == tl.int32, "input must be int32")
    tl.static_assert(val1.dtype == tl.int32, "input must be int32")
    return val0 | val1


@triton.jit
def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
    """
    Compute the suppression bitmask between all pairs of boxes.

    Bit ``j`` of row ``i`` being set means that choosing box ``i`` suppresses box ``j``
    (their IoU is greater than ``threshold``). 32 consecutive columns are packed into one
    output word, most significant bit first: column ``j`` lives in word ``j // 32`` at bit
    ``31 - j % 32``. The words are computed as int32 and stored as int64.

    The kernel returns nothing; the result is written to ``output_ptr``, which is addressed
    as a ``[num_boxes, ceil(num_boxes / 32)]`` matrix.

    Args:
        boxes (tl.pointer): Pointer to the bounding boxes, read as ``num_boxes`` consecutive
            ``(x1, y1, x2, y2)`` quadruples (element ``k`` of box ``i`` is at offset ``4 * i + k``).
        output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
        threshold (float): The IoU threshold for suppressing boxes.
        num_boxes (int): The total number of boxes.
        stride_i (int): The stride of the output tensor along the first dimension.
        stride_j (int): The stride of the output tensor along the second dimension.
        BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel. Must be a multiple of 32.
    """
    # The bit packing below reshapes the columns into groups of exactly 32.
    tl.static_assert(BLOCK_SIZE % 32 == 0, "BLOCK_SIZE must be a multiple of 32")

    # The Triton kernel is a 2D block kernel. The block size is BLOCK_SIZE x BLOCK_SIZE.
    # Each program computes the IoU of boxes[row: row + BLOCK_SIZE] against boxes[col: col + BLOCK_SIZE].
    row_block_pid = tl.program_id(axis=0)
    col_block_pid = tl.program_id(axis=1)

    row_block_start = row_block_pid * BLOCK_SIZE
    col_block_start = col_block_pid * BLOCK_SIZE

    row_block_offsets = row_block_start + tl.arange(0, BLOCK_SIZE)
    col_block_offsets = col_block_start + tl.arange(0, BLOCK_SIZE)

    row_block_mask = row_block_offsets < num_boxes
    col_block_mask = col_block_offsets < num_boxes

    # Since Triton does not support tensor slicing yet, we need to load the point elements individually.
    # Every row block is loaded as a 1-dim tensor of size [BLOCK_SIZE] and then expanded to [BLOCK_SIZE, 1].
    # ``other=0.0`` gives the out-of-range lanes a defined value instead of leaving them undefined.
    row_block_x1 = tl.load(boxes + row_block_offsets * 4 + 0, mask=row_block_mask, other=0.0)[:, None]
    row_block_y1 = tl.load(boxes + row_block_offsets * 4 + 1, mask=row_block_mask, other=0.0)[:, None]
    row_block_x2 = tl.load(boxes + row_block_offsets * 4 + 2, mask=row_block_mask, other=0.0)[:, None]
    row_block_y2 = tl.load(boxes + row_block_offsets * 4 + 3, mask=row_block_mask, other=0.0)[:, None]

    # Expand 1 dim for col, so that the col block dim is [1, BLOCK_SIZE].
    col_block_x1 = tl.load(boxes + col_block_offsets * 4 + 0, mask=col_block_mask, other=0.0)[None, :]
    col_block_y1 = tl.load(boxes + col_block_offsets * 4 + 1, mask=col_block_mask, other=0.0)[None, :]
    col_block_x2 = tl.load(boxes + col_block_offsets * 4 + 2, mask=col_block_mask, other=0.0)[None, :]
    col_block_y2 = tl.load(boxes + col_block_offsets * 4 + 3, mask=col_block_mask, other=0.0)[None, :]

    # Together, the minimum / maximum broadcast into a [BLOCK_SIZE, BLOCK_SIZE] matrix.
    left = tl.maximum(row_block_x1, col_block_x1)
    right = tl.minimum(row_block_x2, col_block_x2)
    top = tl.maximum(row_block_y1, col_block_y1)
    bottom = tl.minimum(row_block_y2, col_block_y2)

    width = tl.maximum(right - left, 0)
    height = tl.maximum(bottom - top, 0)

    intersection = width * height
    area_a = (row_block_x2 - row_block_x1) * (row_block_y2 - row_block_y1)
    area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1)
    union = area_a + area_b - intersection

    # Pairs involving a padding lane (index >= num_boxes) never set a bit.
    valid_pair = row_block_mask[:, None] & col_block_mask[None, :]
    iou_keep_out_bit_mask = (((intersection / union) > threshold) & valid_pair).to(tl.int32)

    # Column c of the block is stored at bit 31 - c % 32 of its word (most significant bit first).
    shift_offsets = (31 - (tl.arange(0, BLOCK_SIZE) % 32))[None, :]
    shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE])
    iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets

    # Combine the bits. Note that Triton seems to have problems when the dtype is int64,
    # thus 32 bits are used for the mask, which is converted to int64 at the end.
    iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
    iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)
    iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)

    # The bits are combined along the columns, thus the column block offset changes.
    # The row offset remains the same.
    combined_col_blk_offsets = col_block_pid * ((BLOCK_SIZE + 31) // 32)
    # NOTE(review): row-major layouts are usually described with order=(1, 0); confirm (0, 1) is intended.
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(num_boxes, (num_boxes + 32 - 1) // 32),
        strides=(stride_i, stride_j),
        offsets=(row_block_start, combined_col_blk_offsets),
        block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32),
        order=(0, 1),
    )
    tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1))
Empty file added torchvision/ops/xpu/__init__.py
Empty file.
62 changes: 62 additions & 0 deletions torchvision/ops/xpu/nms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
import triton
from torchvision.ops.boxes import _nms_kernel_postprocess

from torchvision.ops.triton.nms import triton_nms_IoU_kernel


@torch.library.register_kernel("torchvision::nms", "xpu")
def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than ``threshold`` with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        threshold (float): discards all overlapping boxes with IoU > threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    num_boxes = boxes.shape[0]

    # Nothing to suppress. Returning early also avoids launching the kernel on an empty grid.
    if num_boxes == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Triton does not support argsort yet, thus it needs to fall back to ATen calls.
    order = torch.argsort(scores, descending=True)
    # The kernel addresses box i at offset 4 * i, so it needs a dense (N, 4) layout.
    boxes = boxes[order].contiguous()
    iou_keep_out_mask = torch.zeros(num_boxes, (num_boxes + 32 - 1) // 32, dtype=torch.int64, device=boxes.device)

    grid = lambda meta: (  # noqa: E731
        triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
        triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
    )

    # This triton kernel calculates the IoU matrix for all the input boxes and stores the
    # thresholded result in iou_keep_out_mask.
    # The iou_keep_out_mask is a bitmask matrix with 32 mask bits per item, so its shape is
    # [N, ceil(N / 32)]. Bit j of row i tells whether box j must be dropped when box i is chosen.
    triton_nms_IoU_kernel[grid](
        boxes,
        iou_keep_out_mask,
        threshold,
        num_boxes,
        iou_keep_out_mask.stride(0),
        iou_keep_out_mask.stride(1),
        BLOCK_SIZE=64,
        num_warps=4,
    )

    # The postprocess calculates the final indices of the boxes that should be kept.
    # It is a serialized process, and we choose to run it on CPU for more generalization.
    return _nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device)