Skip to content

Commit

Permalink
[Torch] Various updates for PyTorch frontend (apache#7348)
Browse files Browse the repository at this point in the history
* add conversion for detr

* remove explicit broadcast_to before batched matmul

* use take with wrap mode

* add test for transformer and negative indices

* add sort and argsort

* add logical_and

* support masked_select

* add gpu targets to masked_select test

* improve sort conversion
  • Loading branch information
masahi authored and alexwong committed Feb 11, 2021
1 parent b5d9b7f commit 74b401d
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 14 deletions.
63 changes: 50 additions & 13 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,7 @@ def slice(self, inputs, input_types):
begin = [0] * ndim
dim = int(inputs[1])
stride = int(inputs[4])
if isinstance(inputs[2], _expr.Call):
begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
else:
begin[dim] = int(inputs[2])
begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))

# Process begin
if not isinstance(begin[dim], int):
Expand Down Expand Up @@ -518,13 +515,13 @@ def select(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
index = _wrap_const(inputs[2])
return _op.transform.take(data, index, axis=dim)
return _op.transform.take(data, index, axis=dim, mode="wrap")

def take(self, inputs, input_types):
    """Convert a PyTorch ``take`` node into a Relay ``take`` call.

    Parameters
    ----------
    inputs : list
        ``inputs[0]`` is the data expression and ``inputs[1]`` the indices.
    input_types : list
        Unused here; kept for the common converter signature.

    Returns
    -------
    The Relay expression produced by ``_op.transform.take``.
    """
    data = inputs[0]
    # Indices are cast to int32 before being handed to Relay.
    indices = _op.cast(inputs[1], "int32")

    # A single return using mode="wrap": previously an earlier return without
    # the mode argument made this one unreachable, so out-of-range (e.g.
    # negative) indices were not wrapped around.
    return _op.transform.take(data, indices=indices, mode="wrap")

def topk(self, inputs, input_types):
data = inputs[0]
Expand All @@ -551,7 +548,13 @@ def reciprocal(self, inputs, input_types):

def repeat(self, inputs, input_types):
    """Convert a PyTorch ``repeat`` node into a Relay ``tile`` call.

    Parameters
    ----------
    inputs : list
        ``inputs[0]`` is the data expression; ``inputs[1]`` is an iterable of
        repetition counts, each either a Python ``int`` or an expression that
        is evaluated with ``_infer_value``.
    input_types : list
        Unused here; kept for the common converter signature.

    Returns
    -------
    The Relay expression produced by ``_op.transform.tile``.
    """
    data = inputs[0]
    # Plain ints are used as-is; anything else is evaluated to a concrete
    # value first and then converted to int. (The former dead assignment
    # ``reps = inputs[1]`` was dropped: it was overwritten immediately.)
    reps = [
        r if isinstance(r, int) else int(_infer_value(r, {}).asnumpy())
        for r in inputs[1]
    ]

    return _op.transform.tile(data, reps=reps)

def repeat_interleave(self, inputs, input_types):
Expand Down Expand Up @@ -1520,12 +1523,6 @@ def matmul(self, inputs, input_types):
# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
# Broadcast b to match batch size of a
new_b_shape = list(self.infer_shape_with_prelude(b))
new_a_shape = self.infer_shape_with_prelude(a)
if new_a_shape[0] > new_b_shape[0]:
new_b_shape[0] = new_a_shape[0]
b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
Expand Down Expand Up @@ -2070,6 +2067,40 @@ def scatter_add(self, inputs, input_types):
src = inputs[3]
return _op.scatter_add(data, index, src, axis=axis)

def cumsum(self, inputs, input_types):
    """Convert a cumulative-sum node into ``_op.cumsum``.

    ``inputs`` is read as ``(data, dim, dtype)``. When the dtype entry is
    ``None`` it is forwarded unchanged; otherwise it is first translated with
    ``_convert_dtype_value``.
    """
    data, axis, dtype_code = inputs[0], inputs[1], inputs[2]
    out_dtype = None if dtype_code is None else _convert_dtype_value(dtype_code)
    return _op.cumsum(data, axis=axis, dtype=out_dtype)

def masked_fill(self, inputs, input_types):
    """Convert a masked-fill node into ``_op.where``.

    ``inputs`` is read as ``(data, mask, value)``. The fill value is wrapped
    as a constant and cast to ``input_types[0]`` before being selected
    wherever ``mask`` holds; elsewhere the original data is kept.
    """
    data, mask, fill = inputs[0], inputs[1], inputs[2]
    fill_value = _op.cast(_wrap_const(fill), input_types[0])
    return _op.where(mask, fill_value, data)

def masked_select(self, inputs, input_types):
    """Convert a masked-select node using nonzero + advanced indexing.

    The mask (``inputs[1]``) is passed through ``self.nonzero`` in numpy
    style, and each of the resulting index entries is fed, together with the
    data (``inputs[0]``), to ``_op.adv_index``.
    """
    data, mask = inputs[0], inputs[1]
    nonzero_indices = self.nonzero([mask], input_types, is_numpy_style=True)
    # One index entry per position reported by the wrapper's ``size``.
    index_tensors = [nonzero_indices[axis] for axis in range(nonzero_indices.size)]
    return _op.adv_index([data] + index_tensors)

def sort(self, inputs, input_types):
    """Convert a sort node into an argsort followed by a gather.

    ``inputs`` is read as ``(data, dim, descending)``. Returns a pair
    ``(sorted_values, indices)``, since PyTorch's sort yields both.
    """
    data, axis, descending = inputs[0], inputs[1], inputs[2]
    # The third positional argument of _op.argsort is the negation of
    # PyTorch's ``descending`` flag.
    is_ascend = not descending
    order = _op.argsort(data, axis, is_ascend)
    values = _op.gather(data, axis, order)
    return values, order

def argsort(self, inputs, input_types):
    """Convert an argsort node into ``_op.argsort``.

    ``inputs`` is read as ``(data, dim, descending)``; the descending flag is
    negated before being passed as the third positional argument.
    """
    data, axis, descending = inputs[0], inputs[1], inputs[2]
    is_ascend = not descending
    return _op.argsort(data, axis, is_ascend)

def is_floating_point(self, inputs, input_types):
assert len(inputs) == 1

Expand Down Expand Up @@ -2263,6 +2294,7 @@ def create_convert_map(self):
"torchvision::roi_align": self.roi_align,
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
"aten::logical_and": self.logical_and,
"aten::_shape_as_tensor": self.shape_as_tensor,
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
Expand All @@ -2278,6 +2310,11 @@ def create_convert_map(self):
"aten::__not__": self.logical_not,
"aten::hardswish_": self.hard_swish,
"aten::hardswish": self.hard_swish,
"aten::cumsum": self.cumsum,
"aten::masked_fill": self.masked_fill,
"aten::masked_select": self.masked_select,
"aten::argsort": self.argsort,
"aten::sort": self.sort,
}

def update_convert_map(self, custom_map):
Expand Down
101 changes: 100 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ def forward(self, *args):
@tvm.testing.uses_gpu
def test_forward_select():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_shape = [5, 3, 10, 10]

class Select1(Module):
def forward(self, *args):
Expand All @@ -1167,6 +1167,9 @@ def forward(self, index):
input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)

# test negative indexing
verify_model(lambda x: x[-1], input_data=input_data)

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])
verify_model(IndexedSelect(x, 0).eval(), input_data=indices)
Expand Down Expand Up @@ -2653,6 +2656,8 @@ def forward(self, *args):
verify_model(Take1().float().eval(), input_data=input_data)
indices = torch.tensor([[0, 0], [1, 0]])
verify_model(Take2().float().eval(), input_data=[input_data, indices])
indices = torch.tensor([0, -1])
verify_model(Take2().float().eval(), input_data=[input_data, indices])


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -3452,6 +3457,93 @@ def test_hard_swish():
verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)


def test_cumsum():
    """Check torch.cumsum conversion on integer, float and boolean inputs."""

    def cumsum_fn(dim, dtype=None):
        return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)

    # 1-D integer input, with and without an explicit output dtype.
    ints = torch.randint(0, 100, (10000,), dtype=torch.int32)
    verify_model(cumsum_fn(0), [ints])
    verify_model(cumsum_fn(0), [ints.to(torch.int64)])
    verify_model(cumsum_fn(0, dtype=torch.int64), [ints.to(torch.int64)])

    # 2-D float input along both axes.
    floats = torch.randn((100, 100), dtype=torch.float32)
    verify_model(cumsum_fn(dim=0, dtype=torch.float64), [floats])
    verify_model(cumsum_fn(dim=1), [floats])

    # Boolean input accumulated as int32.
    bools = torch.randn((100, 100), dtype=torch.float32) > 0.5
    verify_model(cumsum_fn(dim=0, dtype=torch.int32), [bools])


def test_masked_fill():
    """Check torch.masked_fill conversion for float32 and float64 data."""

    def fill_zero(x, mask):
        return torch.masked_fill(x, mask, 0.0)

    base = torch.randn(100, 100)
    # Same data, first as generated and then converted to float64; the mask
    # is always derived from the original tensor.
    for data in (base, base.to(torch.float64)):
        verify_model(fill_zero, [data, base > 0.5])


def test_transformer():
    """Check conversion of torch.nn.Transformer in eval mode."""
    # eval() returns the module itself, so a single call is enough; the
    # previous code invoked it a second time when calling verify_model.
    model = torch.nn.Transformer(
        d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6
    ).eval()
    src = torch.rand((10, 32, 256))
    tgt = torch.rand((20, 32, 256))
    verify_model(model, input_data=[src, tgt])


def test_argsort():
    """Check torch.argsort conversion on 1-D and 2-D inputs, both orders."""

    def argsort_fn(dim, descending):
        return lambda x: torch.argsort(x, dim=dim, descending=descending)

    vector = torch.randn(100)
    for descending in (True, False):
        verify_model(argsort_fn(0, descending), [vector])

    matrix = torch.randn(100, 100)
    for dim in (0, 1):
        for descending in (True, False):
            verify_model(argsort_fn(dim, descending), [matrix])


def test_sort():
    """Check torch.sort conversion on 1-D and 2-D inputs, both orders."""

    def sort_fn(dim, descending):
        return lambda x: torch.sort(x, dim=dim, descending=descending)

    vector = torch.randn(100)
    for descending in (True, False):
        verify_model(sort_fn(0, descending), [vector])

    matrix = torch.randn(100, 100)
    for dim in (0, 1):
        for descending in (True, False):
            verify_model(sort_fn(dim, descending), [matrix])


def test_logical_and():
    """Check torch.logical_and conversion for int8 and boolean operands."""

    def logical_and_fn(x, y):
        return torch.logical_and(x, y)

    operand_pairs = [
        (
            torch.tensor([0, 1, 10, 0], dtype=torch.int8),
            torch.tensor([4, 0, 1, 0], dtype=torch.int8),
        ),
        (
            torch.tensor([True, False, True]),
            torch.tensor([True, False, False]),
        ),
    ]
    for lhs, rhs in operand_pairs:
        verify_model(logical_and_fn, [lhs, rhs])


def test_masked_select():
    """Check torch.masked_select conversion across several input ranks."""

    def select_fn(x, mask):
        return torch.masked_select(x, mask)

    targets = ["llvm", "cuda", "nvptx"]
    shapes = [(10,), (3, 4), (16, 32, 64)]
    for shape in shapes:
        data = torch.randn(*shape)
        # Mask keeps the elements that are >= 0.5.
        keep = data.ge(0.5)
        verify_trace_model(select_fn, [data, keep], targets)


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
Expand Down Expand Up @@ -3580,6 +3672,13 @@ def test_hard_swish():
test_forward_scatter()
test_numel()
test_bincount()
test_cumsum()
test_masked_fill()
test_transformer()
test_sort()
test_argsort()
test_logical_and()
test_masked_select()

# Model tests
test_resnet18()
Expand Down

0 comments on commit 74b401d

Please sign in to comment.