diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 6fc5a6af4a36..248f5354cfbb 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -16,13 +16,16 @@ # under the License. # pylint: disable=import-outside-toplevel, unused-argument, invalid-name """ Common utilities used by PyTorch frontend """ +from .. import expr from .. import op from ..dataflow_pattern import ( + wildcard, is_constant, is_op, rewrite, is_tuple, - wildcard, + is_tuple_get_item, + is_if, DFPatternCallback, ) @@ -36,6 +39,19 @@ def is_version_greater_than(ver): ) +def dyn_strided_slice_pattern(inp, end): + """A pattern to detect dynamic strided slice op.""" + zero = is_constant() + cast_like = is_op("cast_like")(zero, is_constant()) + less = is_op("less")(is_constant(), cast_like) + shape_of = is_op("shape_of")(inp) + cast_like = is_op("cast_like")(shape_of, is_constant()) + add = is_op("add")(is_constant(), cast_like) + where = is_op("where")(less, add, is_constant()) + + return is_op("dyn.strided_slice")(inp, where, end, is_constant()) + + def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices): """A pattern to detect batched_nms function in torchvision @@ -73,7 +89,6 @@ def batched_nms(boxes, scores, idxs, iou_threshold): """ one = is_constant() - zero = is_constant() # Equivelent PyTorch code from above snippet # offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) @@ -84,17 +99,10 @@ def batched_nms(boxes, scores, idxs, iou_threshold): # The following doesn't appear in the above Relay snippet. 
It is required for dynamic # stride_slice handling - cast_like = is_op("cast_like")(zero, is_constant()) - less = is_op("less")(is_constant(), cast_like) - shape_of = is_op("shape_of")(mul) - cast_like = is_op("cast_like")(shape_of, is_constant()) - add = is_op("add")(is_constant(), cast_like) - where = is_op("where")(less, add, is_constant()) shape_of = is_op("shape_of")(mul) cast = is_op("cast")(shape_of) - # This corresponds to offsets[:, None], where offsets is the result of multiplication - dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant()) + dyn_strided_slice = dyn_strided_slice_pattern(mul, cast) # Add offsets to the boxes expand_dims = is_op("expand_dims")(dyn_strided_slice) @@ -112,8 +120,49 @@ def batched_nms(boxes, scores, idxs, iou_threshold): ) -class NMSRewrite(DFPatternCallback): - """A callback to rewrite nms and restore batched nms""" +def topk_after_batch_nms_pattern(cond, true_branch, data, valid_count, indices, iou_threshold): + """ + Detect the following pattern used in torchvision detection models. + + def batched_nms(...): + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + else: + ... + return nms(boxes_for_nms, scores, iou_threshold) + + keep = batched_nms(boxes, scores, lvl, self.nms_thresh) + keep = keep[:post_nms_top_k] # keep only topk scoring predictions + + An equivalent Relay subgraph: + + %1184 = if (%1117) { + ... + } else { + ... + %1172 = vision.non_max_suppression(%1167, %1168, %1171, -1, 0.7f, ...); + ... 
+ %1183 = dyn.strided_slice(%1174, %1180, %1182, ...); + cast(%1183, dtype="int64") + }; + %1185 = strided_slice(%1184, begin=[0], end=[1000], strides=[1]); + + """ + nms = is_op("vision.non_max_suppression")( + data, valid_count, indices, is_constant(), iou_threshold + ) + indices = is_op("squeeze")(is_tuple_get_item(nms, 0)) + size = is_op("squeeze")(is_tuple_get_item(nms, 1)) + dyn_strided_slice = dyn_strided_slice_pattern(indices, size) + cast_i64 = is_op("cast")(dyn_strided_slice) + + batched_nms_result = is_if(cond, true_branch, cast_i64) + + return is_op("strided_slice")(batched_nms_result) + + +class MulticlassNMSRewrite(DFPatternCallback): + """A callback to rewrite nms and restore batched nms.""" def __init__(self): super().__init__() @@ -169,10 +218,193 @@ def callback(self, pre, post, node_map): return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices) +class PostNMSTopKRewrite(DFPatternCallback): + """A callback to rewrite nms to exploit max_out_size parameter.""" + + def __init__(self): + super().__init__() + self.cond = wildcard() + self.true_branch = wildcard() + self.data = wildcard() + self.valid_count = wildcard() + self.indices = wildcard() + self.iou_threshold = wildcard() + + self.pattern = topk_after_batch_nms_pattern( + self.cond, + self.true_branch, + self.data, + self.valid_count, + self.indices, + self.iou_threshold, + ) + + def rewrite_batch_nms_with_max_out_size( + self, cond, true_branch, data, valid_count, indices, iou_threshold, post_nms_topk + ): + """Use the detected post NMS topk parameter in NMS op.""" + nms_ret = op.vision.non_max_suppression( + data=data, + valid_count=valid_count, + indices=indices, + max_output_size=post_nms_topk, + iou_threshold=iou_threshold, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + + size = op.squeeze(nms_ret[1], axis=[1]) + data_slice = op.squeeze(nms_ret[0], axis=[0]) + + ret = 
op.strided_slice(data_slice, begin=expr.const([0]), end=size, slice_mode="size") + + nms_result = op.cast(ret, "int64") + + return expr.If(cond, true_branch, nms_result) + + def callback(self, pre, post, node_map): + post_nms_topk = post.attrs.end[0].value + return self.rewrite_batch_nms_with_max_out_size( + node_map[self.cond][0], + node_map[self.true_branch][0], + node_map[self.data][0], + node_map[self.valid_count][0], + node_map[self.indices][0], + node_map[self.iou_threshold][0], + post_nms_topk, + ) + + +def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales): + """Detect the Relay subgraph corresponding to the following PyTorch code + + first_result = roi_align_results[0] + dtype, device = first_result.dtype, first_result.device + res = torch.zeros((levels.size(0), first_result.size(1), + first_result.size(2), first_result.size(3)), + dtype=dtype, device=device) + for level in range(len(roi_align_results)): + index = torch.where(levels == level)[0].view(-1, 1, 1, 1) + index = index.expand(index.size(0), + roi_align_results[level].size(1), + roi_align_results[level].size(2), + roi_align_results[level].size(3)) + res = res.scatter(0, index, roi_align_results[level]) + return res + """ + + def do_where(levels, _): + idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant())) + idx_in_level = is_op("split")(idx_in_level) + idx_in_level = is_tuple_get_item(idx_in_level, 0) + idx_in_level = is_op("squeeze")(idx_in_level) + idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0) + return idx_in_level + + scatter_res = wildcard() + + for i in range(num_scales): + # index = torch.where(levels == level)[0].view(-1, 1, 1, 1) + scatter_indices = do_where(levels, i) + scatter_indices = is_op("reshape")(scatter_indices) + + # index = index.expand(index.size(0), + # unmerged_results[level].size(1), + # unmerged_results[level].size(2), + # unmerged_results[level].size(3)) + scatter_indices = is_op("repeat")(scatter_indices) + 
scatter_indices = is_op("repeat")(scatter_indices) + scatter_indices = is_op("repeat")(scatter_indices) + + scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i]) + + return is_op("reshape")(scatter_res) + + +class ScatterRewrite(DFPatternCallback): + """A callback to rewrite repeated scatters with a batched gather.""" + + def __init__(self, num_scales): + super().__init__() + self.num_scales = num_scales + self.levels = wildcard() + self.roi_align_results = [] + for _ in range(num_scales): + self.roi_align_results.append(wildcard()) + + self.pattern = scatter_roi_align_result_pattern( + self.levels, self.roi_align_results, num_scales + ) + + def convert_scatter_to_gather(self, levels, roi_align_results): + """Replace the detected scatter loop with the following PyTorch code + + indices_per_level = [] + for level in range(num_scales): + idx_in_level = torch.where(levels == level)[0] + indices_per_level.append(idx_in_level) + + stacked_features = torch.cat(roi_align_results, dim=0) + stacked_indices = torch.cat(indices_per_level, dim=0) + argsort_indices = torch.argsort(stacked_indices) + return stacked_features[argsort_indices, :] + """ + + # Collect indices and concat them + indices_per_level = [] + for i in range(self.num_scales): + equal = op.equal(levels, expr.const(i, dtype="int64")) + argwhere = op.argwhere(equal) + split = op.split(argwhere, indices_or_sections=1, axis=1) + squeeze = op.squeeze(split[0], axis=[1]) + indices = op.cast(squeeze, dtype="int64") + indices_per_level.append(indices) + + indices_concat = op.concatenate(indices_per_level, 0) + + # Concat roi align results per level, and argsort indices + # To prepare for a batched gather + roi_align_results_concat = op.concatenate(roi_align_results, 0) + argsort_indices = op.cast(op.argsort(indices_concat), dtype="int64") + + # Permute rows by argsorted indices + permuted = op.take(roi_align_results_concat, argsort_indices, axis=0) + + return op.reshape(permuted, [0, -1, 1, 
1]) + + def callback(self, pre, post, node_map): + levels = node_map[self.levels][0] + roi_align_results = [node_map[feat][0] for feat in self.roi_align_results] + return self.convert_scatter_to_gather(levels, roi_align_results) + + def rewrite_nms_to_batched_nms(mod): """Rewrite the input graph to replace non maximum surpression in torchvision that does not take class id into account with the one that avoids IOU tests between different classes. """ - mod["main"] = rewrite(NMSRewrite(), mod["main"]) + mod["main"] = rewrite(MulticlassNMSRewrite(), mod["main"]) + return mod + + +def rewrite_batched_nms_with_max_out_size(mod): + """Rewrite the input graph to detect slicing after batched nms and + use the slicing size as the parameter max_out_size in NMS. + """ + mod["main"] = rewrite(PostNMSTopKRewrite(), mod["main"]) + return mod + + +def rewrite_scatter_to_gather(mod, num_scales): + """Rewrite the input graph to replace a repeated scatter loop with + a batched gather. The scatter loop is used in torchvision MultiScaleRoIAlign + to merge roi_align results for all scales. The scatter is used to emulate + inplace updates. 
+ """ + mod["main"] = rewrite(ScatterRewrite(num_scales), mod["main"]) return mod diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 2c323776f087..fd33dd1da8b1 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -26,7 +26,11 @@ import tvm.testing from tvm import relay from tvm.runtime.vm import VirtualMachine -from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms +from tvm.relay.frontend.pytorch_utils import ( + rewrite_nms_to_batched_nms, + rewrite_batched_nms_with_max_out_size, + rewrite_scatter_to_gather, +) from tvm.contrib.download import download @@ -72,7 +76,7 @@ def generate_jit_model(index): ] model_func = model_funcs[index] - model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200)) + model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=1000)) model.eval() inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size))) @@ -141,6 +145,16 @@ def compile_and_run_vm(mod, params, data_np, target): after = mod["main"] assert not tvm.ir.structural_equal(after, before) + before = mod["main"] + mod = rewrite_batched_nms_with_max_out_size(mod) + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) + + before = mod["main"] + mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) + tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm") # Results should be equivalent after rewriting