Skip to content

Commit

Permalink
[Torch] More graph rewrites for Faster RCNN / MaskRCNN (apache#7346)
Browse files Browse the repository at this point in the history
* add post nms topk to max_out_size rewrite

* add argsort conversion

* scatter pattern first cut

* matching seems to be working

* dup matching fixed

* add converter

* conversion seems working

* add reshape, use take

* remove pytorch argsort converter

* update test

* add doc
  • Loading branch information
masahi authored and trevor-m committed Mar 2, 2021
1 parent ee8b49c commit dcc1c77
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 15 deletions.
258 changes: 245 additions & 13 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument, invalid-name
""" Common utilities used by PyTorch frontend """
from .. import expr
from .. import op
from ..dataflow_pattern import (
wildcard,
is_constant,
is_op,
rewrite,
is_tuple,
wildcard,
is_tuple_get_item,
is_if,
DFPatternCallback,
)

Expand All @@ -36,6 +39,19 @@ def is_version_greater_than(ver):
)


def dyn_strided_slice_pattern(inp, end):
    """Return a pattern matching a dynamic strided slice over ``inp``.

    The `begin` argument of ``dyn.strided_slice`` is matched as the Relay
    lowering of negative-index normalization:
    ``where(begin < 0, begin + shape_of(inp), begin)``.
    """
    casted_zero = is_op("cast_like")(is_constant(), is_constant())
    is_negative = is_op("less")(is_constant(), casted_zero)
    casted_shape = is_op("cast_like")(is_op("shape_of")(inp), is_constant())
    wrapped_begin = is_op("add")(is_constant(), casted_shape)
    begin = is_op("where")(is_negative, wrapped_begin, is_constant())

    return is_op("dyn.strided_slice")(inp, begin, end, is_constant())


def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
"""A pattern to detect batched_nms function in torchvision
Expand Down Expand Up @@ -73,7 +89,6 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
"""
one = is_constant()
zero = is_constant()

# Equivelent PyTorch code from above snippet
# offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
Expand All @@ -84,17 +99,10 @@ def batched_nms(boxes, scores, idxs, iou_threshold):

# The following doesn't appear in the above Relay snippet. It is required for dynamic
# stride_slice handling
cast_like = is_op("cast_like")(zero, is_constant())
less = is_op("less")(is_constant(), cast_like)
shape_of = is_op("shape_of")(mul)
cast_like = is_op("cast_like")(shape_of, is_constant())
add = is_op("add")(is_constant(), cast_like)
where = is_op("where")(less, add, is_constant())
shape_of = is_op("shape_of")(mul)
cast = is_op("cast")(shape_of)

# This corresponds to offsets[:, None], where offsets is the result of multiplication
dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant())
dyn_strided_slice = dyn_strided_slice_pattern(mul, cast)

# Add offsets to the boxes
expand_dims = is_op("expand_dims")(dyn_strided_slice)
Expand All @@ -112,8 +120,49 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
)


class NMSRewrite(DFPatternCallback):
"""A callback to rewrite nms and restore batched nms"""
def topk_after_batch_nms_pattern(cond, true_branch, data, valid_count, indices, iou_threshold):
    """Return a pattern for post-NMS topk slicing in torchvision detection models.

    The corresponding PyTorch code is:

        def batched_nms(...):
            if boxes.numel() == 0:
                return torch.empty((0,), dtype=torch.int64, device=boxes.device)
            else:
                ...
                return nms(boxes_for_nms, scores, iou_threshold)

        keep = batched_nms(boxes, scores, lvl, self.nms_thresh)
        keep = keep[:post_nms_top_k]  # keep only topk scoring predictions

    which lowers to a Relay ``if`` whose else-branch ends in
    ``vision.non_max_suppression`` + ``dyn.strided_slice`` + ``cast``,
    followed by a static ``strided_slice`` over the ``if`` result:

        %1184 = if (%1117) {
          ...
        } else {
          ...
          %1172 = vision.non_max_suppression(%1167, %1168, %1171, -1, 0.7f, ...);
          ...
          %1183 = dyn.strided_slice(%1174, %1180, %1182, ...);
          cast(%1183, dtype="int64")
        };
        %1185 = strided_slice(%1184, begin=[0], end=[1000], strides=[1]);
    """
    nms_out = is_op("vision.non_max_suppression")(
        data, valid_count, indices, is_constant(), iou_threshold
    )
    # Tuple item 0 holds the selected indices, item 1 the valid count.
    kept_indices = is_op("squeeze")(is_tuple_get_item(nms_out, 0))
    kept_count = is_op("squeeze")(is_tuple_get_item(nms_out, 1))
    sliced = dyn_strided_slice_pattern(kept_indices, kept_count)
    as_int64 = is_op("cast")(sliced)

    if_result = is_if(cond, true_branch, as_int64)

    # The outer static slice carries the post-NMS topk value in its `end` attr.
    return is_op("strided_slice")(if_result)


class MulticlassNMSRewrite(DFPatternCallback):
"""A callback to rewrite nms and restore batched nms."""

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -169,10 +218,193 @@ def callback(self, pre, post, node_map):
return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices)


class PostNMSTopKRewrite(DFPatternCallback):
    """A callback to rewrite nms to exploit max_out_size parameter.

    Detects a static ``strided_slice`` applied after batched NMS and folds
    the slice length into the NMS op's ``max_output_size`` argument.
    """

    # Free variables of the matched subgraph, in the order the pattern and
    # the rewrite expect them.
    _VAR_NAMES = ("cond", "true_branch", "data", "valid_count", "indices", "iou_threshold")

    def __init__(self):
        super().__init__()
        for var_name in self._VAR_NAMES:
            setattr(self, var_name, wildcard())

        self.pattern = topk_after_batch_nms_pattern(
            self.cond,
            self.true_branch,
            self.data,
            self.valid_count,
            self.indices,
            self.iou_threshold,
        )

    def rewrite_batch_nms_with_max_out_size(
        self, cond, true_branch, data, valid_count, indices, iou_threshold, post_nms_topk
    ):
        """Rebuild the NMS subgraph, passing the detected topk as max_output_size."""
        nms_out = op.vision.non_max_suppression(
            data=data,
            valid_count=valid_count,
            indices=indices,
            max_output_size=post_nms_topk,
            iou_threshold=iou_threshold,
            force_suppress=False,
            top_k=-1,
            coord_start=2,
            score_index=1,
            id_index=0,
            return_indices=True,
            invalid_to_bottom=False,
        )

        # Tuple item 1 is the number of valid boxes, item 0 the indices.
        valid_size = op.squeeze(nms_out[1], axis=[1])
        kept = op.squeeze(nms_out[0], axis=[0])

        # slice_mode="size": take `valid_size` elements starting at 0.
        sliced = op.strided_slice(kept, begin=expr.const([0]), end=valid_size, slice_mode="size")
        as_int64 = op.cast(sliced, "int64")

        return expr.If(cond, true_branch, as_int64)

    def callback(self, pre, post, node_map):
        # `post` is the outer strided_slice; its end[0] is the post-NMS topk.
        topk = post.attrs.end[0].value
        matched = [node_map[getattr(self, name)][0] for name in self._VAR_NAMES]
        return self.rewrite_batch_nms_with_max_out_size(*matched, topk)


def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales):
    """Return a pattern for the scatter loop in torchvision MultiScaleRoIAlign.

    The corresponding PyTorch code is:

        first_result = roi_align_results[0]
        dtype, device = first_result.dtype, first_result.device
        res = torch.zeros((levels.size(0), first_result.size(1),
                           first_result.size(2), first_result.size(3)),
                          dtype=dtype, device=device)
        for level in range(len(roi_align_results)):
            index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
            index = index.expand(index.size(0),
                                 roi_align_results[level].size(1),
                                 roi_align_results[level].size(2),
                                 roi_align_results[level].size(3))
            res = res.scatter(0, index, roi_align_results[level])
        return res
    """

    def level_index_pattern(levels_pat):
        # torch.where(levels == level)[0], lowered to argwhere/split/squeeze.
        matched = is_op("argwhere")(is_op("equal")(levels_pat, is_constant()))
        first_col = is_tuple_get_item(is_op("split")(matched), 0)
        squeezed = is_op("squeeze")(first_col)
        return is_tuple_get_item(is_tuple([squeezed]), 0)

    scatter_chain = wildcard()

    for scale in range(num_scales):
        # index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index_pat = is_op("reshape")(level_index_pattern(levels))

        # index.expand(...) over the three trailing axes lowers to repeats.
        for _ in range(3):
            index_pat = is_op("repeat")(index_pat)

        scatter_chain = is_op("scatter")(scatter_chain, index_pat, roi_align_results[scale])

    return is_op("reshape")(scatter_chain)


class ScatterRewrite(DFPatternCallback):
    """A callback to rewrite repeated scatters with a batched gather."""

    def __init__(self, num_scales):
        super().__init__()
        # Number of feature-pyramid scales; fixes the length of the scatter chain.
        self.num_scales = num_scales
        self.levels = wildcard()
        self.roi_align_results = [wildcard() for _ in range(num_scales)]

        self.pattern = scatter_roi_align_result_pattern(
            self.levels, self.roi_align_results, num_scales
        )

    def convert_scatter_to_gather(self, levels, roi_align_results):
        """Replace the detected scatter loop with the following PyTorch code

           indices_per_level = []
           for level in range(num_scales):
               idx_in_level = torch.where(levels == level)[0]
               indices_per_level.append(idx_in_level)

           stacked_features = torch.cat(roi_align_results, dim=0)
           stacked_indices = torch.cat(indices_per_level, dim=0)
           argsort_indices = torch.argsort(stacked_indices)
           return stacked_features[argsort_indices, :]
        """

        # Collect per-level row indices and concat them.
        per_level_indices = []
        for scale in range(self.num_scales):
            matched = op.equal(levels, expr.const(scale, dtype="int64"))
            first_col = op.split(op.argwhere(matched), indices_or_sections=1, axis=1)[0]
            per_level_indices.append(op.cast(op.squeeze(first_col, axis=[1]), dtype="int64"))

        stacked_indices = op.concatenate(per_level_indices, 0)

        # Concat roi align results per level, and argsort indices
        # to prepare for a batched gather.
        stacked_features = op.concatenate(roi_align_results, 0)
        order = op.cast(op.argsort(stacked_indices), dtype="int64")

        # Permute rows by the argsorted indices.
        gathered = op.take(stacked_features, order, axis=0)

        return op.reshape(gathered, [0, -1, 1, 1])

    def callback(self, pre, post, node_map):
        levels = node_map[self.levels][0]
        results = [node_map[pat][0] for pat in self.roi_align_results]
        return self.convert_scatter_to_gather(levels, results)


def rewrite_nms_to_batched_nms(mod):
    """Rewrite the input graph to replace non maximum suppression
    in torchvision that does not take class id into account with the one
    that avoids IOU tests between different classes.
    """
    # Note: the rewrite callback class is named MulticlassNMSRewrite; a stale
    # reference to the old NMSRewrite name was removed here.
    mod["main"] = rewrite(MulticlassNMSRewrite(), mod["main"])
    return mod


def rewrite_batched_nms_with_max_out_size(mod):
    """Rewrite the input graph to detect slicing after batched nms and
    use the slicing size as the parameter max_out_size in NMS.
    """
    rewritten_main = rewrite(PostNMSTopKRewrite(), mod["main"])
    mod["main"] = rewritten_main
    return mod


def rewrite_scatter_to_gather(mod, num_scales):
    """Rewrite the input graph to replace a repeated scatter loop with
    a batched gather. The scatter loop is used in torchvision
    MultiScaleRoIAlign to merge roi_align results for all scales. The
    scatter is used to emulate inplace updates.
    """
    rewritten_main = rewrite(ScatterRewrite(num_scales), mod["main"])
    mod["main"] = rewritten_main
    return mod
18 changes: 16 additions & 2 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms
from tvm.relay.frontend.pytorch_utils import (
rewrite_nms_to_batched_nms,
rewrite_batched_nms_with_max_out_size,
rewrite_scatter_to_gather,
)
from tvm.contrib.download import download


Expand Down Expand Up @@ -72,7 +76,7 @@ def generate_jit_model(index):
]

model_func = model_funcs[index]
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=1000))

model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
Expand Down Expand Up @@ -141,6 +145,16 @@ def compile_and_run_vm(mod, params, data_np, target):
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

before = mod["main"]
mod = rewrite_batched_nms_with_max_out_size(mod)
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

before = mod["main"]
mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")

# Results should be equivalent after rewriting
Expand Down

0 comments on commit dcc1c77

Please sign in to comment.