Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Sep 11, 2020
1 parent db1da32 commit ac009ab
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 10 deletions.
19 changes: 14 additions & 5 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,8 +2063,13 @@ def _impl(inputs, input_types):

output_size = (inputs[3], inputs[4])
spatial_scale = inputs[2]
sample_ratio = inputs[5]
aligned = inputs[6]

return _op.vision.roi_align(data, boxes, output_size, spatial_scale)
if aligned:
data -= _expr.const(0.5)

return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)
return _impl

def _unbind():
Expand Down Expand Up @@ -2113,12 +2118,15 @@ def _impl(inputs, input_types):
return _op.logical_and(lhs, rhs)
return _impl

def _nonzero():
def _nonzero(is_numpy_style):
    """Return a converter for aten::nonzero / aten::nonzero_numpy.

    Parameters
    ----------
    is_numpy_style : bool
        True when converting aten::nonzero_numpy, which always behaves like
        torch.nonzero(..., as_tuple=True); False for plain aten::nonzero,
        where inputs[1] (if present) carries the as_tuple flag.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        # argwhere yields one row of indices per non-zero element of data.
        ret = _op.transform.argwhere(data)

        # as_tuple semantics would require splitting the index matrix into a
        # tuple of 1-D tensors, which needs an unbind op.
        if is_numpy_style or (len(inputs) > 1 and inputs[1]):
            # TODO(kevinthesun): Support this by adding unbind op
            # ret = _unbind()([ret, 0], None)
            raise RuntimeError("as_tuple is not supported yet for nonzero.")
        return ret
    return _impl

Expand Down Expand Up @@ -2485,7 +2493,8 @@ def _get_convert_map(prelude, default_dtype):
"aten::unbind" : _unbind(),
"aten::__and__" : _logical_and(),
"aten::_shape_as_tensor" : _shape_as_tensor(prelude),
"aten::nonzero" : _nonzero(),
"aten::nonzero" : _nonzero(False),
"aten::nonzero_numpy" : _nonzero(True),
"aten::scatter" : _scatter(),
"aten::scalar_tensor" : _scalar_tensor(),
"aten::__interpolate" : _interpolate(),
Expand Down
87 changes: 82 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,33 @@ def _gen_rand_inputs(num_boxes):
verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores], targets)


def test_forward_roi_align():
    """ROI align"""
    torch.set_grad_enabled(False)

    class ROIAlign(Module):  # was misspelled "ROIAlgin"
        """Module wrapper around torchvision.ops.roi_align for tracing."""

        def __init__(self, output_sizes, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
            super().__init__()
            self.spatial_scale = spatial_scale
            self.sampling_ratio = sampling_ratio
            self.aligned = aligned
            self.output_sizes = output_sizes

        def forward(self, *args):
            # args[0]: feature map, args[1]: boxes (batch index + 4 coords)
            return torchvision.ops.roi_align(args[0], args[1], self.output_sizes,
                                             self.spatial_scale, self.sampling_ratio,
                                             self.aligned)

    in_data = torch.Tensor(np.random.uniform(size=(1, 8, 100, 100)))
    in_boxes = torch.Tensor(np.random.uniform(0.0, 100.0, size=(35, 4)))
    # roi_align expects each box to be prefixed with its batch index.
    in_batch = torch.zeros((35, 1), dtype=torch.float)
    in_boxes = torch.cat([in_batch, in_boxes], dim=1)

    verify_model(ROIAlign(7), [in_data, in_boxes])
    verify_model(ROIAlign((10, 10), 0.7, 5), [in_data, in_boxes])
    verify_model(ROIAlign(15, 0.9, 3, False), [in_data, in_boxes])


@tvm.testing.uses_gpu
def test_conv3d():
for ishape in [(1, 32, 16, 16, 16),
Expand Down Expand Up @@ -1661,8 +1688,8 @@ def test_conv3d_transpose():
padding=(0, 4, 2)).eval(),
inp),
verify_model(torch.nn.ConvTranspose3d(in_channels=8,
out_channels=20,
kernel_size=1).eval(),
out_channels=20,
kernel_size=1).eval(),
inp)
verify_model(torch.nn.ConvTranspose3d(in_channels=8,
out_channels=5,
Expand Down Expand Up @@ -2954,6 +2981,54 @@ def forward(self, x):
verify_script_model(Stack(), [(8, 8, 8)], _get_default_vm_targets())


def test_forward_unbind():
    """Unbind an 8x8x8 tensor along each of its three axes."""
    class Unbind(torch.nn.Module):
        def __init__(self, axis=0):
            super().__init__()
            self.axis = axis

        def forward(self, x):
            return torch.unbind(x, self.axis)

    inp = torch.randn(8, 8, 8)
    for dim in range(3):
        verify_model(Unbind(dim), input_data=inp)


def test_forward_nonzero():
    """torch.nonzero with the default as_tuple=False."""
    class Nonzero(Module):
        def __init__(self, as_tuple=False):
            super().__init__()
            self.as_tuple = as_tuple

        def forward(self, data):
            return torch.nonzero(data, as_tuple=self.as_tuple)

    values = np.array([[0, 1, 0], [2, 0, 9], [-1, -1, 0]], dtype="float32")
    verify_trace_model(Nonzero(), [torch.Tensor(values)], ['llvm'])


def test_forward_scatter():
    """torch.scatter along dim 0 and dim 1."""
    class Scatter(Module):
        def __init__(self, dim=0):
            super().__init__()
            self.dim = dim

        def forward(self, data, index, src):
            return torch.scatter(data, dim=self.dim, index=index, src=src)

    # dim=0: scatter a 2x5 source into a 3x5 destination.
    verify_model(
        Scatter(),
        input_data=[
            torch.zeros(3, 5),
            torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),
            torch.rand(2, 5),
        ],
    )

    # dim=1: scatter a 2x1 source into a 2x4 destination.
    verify_model(
        Scatter(1),
        input_data=[
            torch.zeros(2, 4),
            torch.tensor([[2], [3]]),
            torch.rand(2, 1),
        ],
    )


def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
Expand Down Expand Up @@ -3083,7 +3158,6 @@ def test_forward_pretrained_bert_base_uncased():


if __name__ == "__main__":
"""
# some structural tests
test_forward_traced_function()
test_forward_dtypes()
Expand Down Expand Up @@ -3180,6 +3254,7 @@ def test_forward_pretrained_bert_base_uncased():
test_upsample()
test_forward_upsample3d()
test_forward_nms()
test_forward_roi_align()
test_to()
test_flatten()
test_type_as()
Expand All @@ -3201,6 +3276,9 @@ def test_forward_pretrained_bert_base_uncased():
test_logsumexp()
test_stack()
test_stack_dynamic()
test_forward_unbind()
test_forward_nonzero()
test_forward_scatter()

# Model tests
test_resnet18()
Expand Down Expand Up @@ -3228,10 +3306,9 @@ def test_forward_pretrained_bert_base_uncased():
# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()
"""

# More complex recurrent models
from lstm_test import test_custom_lstm
from test_lstm import test_custom_lstm

test_custom_lstm()

Expand Down

0 comments on commit ac009ab

Please sign in to comment.