Skip to content

Commit

Permalink
add support for aten::unflatten
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Nov 15, 2023
1 parent 707492a commit bfbef63
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,16 @@ def flatten(self, inputs, input_types):
out = _op.squeeze(out, axis=squeeze_axes)
return out

def unflatten(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
unflattened_size = tuple(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
assert len(dshape) > dim
new_shape = dshape[:dim] + unflattened_size + dshape[dim+1:]
out = _op.reshape(data, new_shape)
return out

def addmm(self, inputs, input_types):
input_mat = inputs[0]
mat1 = inputs[1]
Expand Down Expand Up @@ -3945,6 +3955,7 @@ def create_convert_map(self):
"aten::t": self.transpose,
"aten::numpy_T": self.numpy_T,
"aten::flatten": self.flatten,
"aten::unflatten": self.unflatten,
"aten::addmm": self.addmm,
"aten::size": self.size,
"aten::view": self.view,
Expand Down
28 changes: 28 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,34 @@ def _test_flatten(start_dim, end_dim):
verify_model(_test_flatten(-3, -2), inp)


#@tvm.testing.uses_gpu
def test_unflatten():
"""test_unflatten"""

def _test_unflatten(dim, unflattened_size):
return lambda inp: torch.unflatten(inp, dim, unflattened_size)

inp = torch.rand(60)

# [60] -> [3, 5, 2, 2]
verify_model(_test_unflatten(0, (3, 5, 2, 2)), inp)
verify_model(_test_unflatten(0, (-1, 5, 2, 2)), inp)
verify_model(_test_unflatten(0, (3, -1, 2, 2)), inp)
verify_model(_test_unflatten(0, (3, 5, -1, 2)), inp)
verify_model(_test_unflatten(0, (3, 5, 2, -1)), inp)

inp = torch.rand(3, 4, 1)

# [3, 4, 1] -> [3, 2, 2, 1]
verify_model(_test_unflatten(1, (2, 2)), inp)
verify_model(_test_unflatten(1, (-1, 2)), inp)

inp = torch.rand(5, 12, 3)

# [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
verify_model(_test_unflatten(-2, (2, 2, 3, 1, 1)), inp)


@tvm.testing.uses_gpu
def test_forward_transpose():
"""test_forward_transpose"""
Expand Down

0 comments on commit bfbef63

Please sign in to comment.