From f2e6107b56bb921306f4969ff03251e13686e4c1 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Tue, 14 Nov 2023 13:19:02 +0900
Subject: [PATCH 1/6] add support for `aten::linalg_vector_norm`

---
 python/tvm/relay/frontend/pytorch.py          | 25 ++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 33 +++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index bdfd8f78b22e..aff2786cef7f 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3844,6 +3844,30 @@ def inplace_copy(self, inputs, input_types):
         # Return
         return _op.scatter_nd(source, indices, values, mode)
 
+    def linalg_vector_norm(self, inputs, input_types):
+        data = inputs[0]
+        dtype = input_types[0]
+        ord = inputs[1]
+        dim = inputs[2]
+        keepdim = inputs[3]
+
+        if ord == 0:
+            return _op.reduce.sum(
+                _op.cast(_op.not_equal(data, _expr.const(0, dtype=dtype)), dtype=dtype),
+                axis=dim,
+                keepdims=keepdim,
+            )
+        elif ord == np.inf:
+            return _op.reduce.max(_op.abs(data), axis=dim, keepdims=keepdim)
+        elif ord == np.NINF:
+            return _op.reduce.min(_op.abs(data), axis=dim, keepdims=keepdim)
+        reci_ord = _expr.const(1.0 / ord, dtype=dtype)
+        ord = _expr.const(ord, dtype=dtype)
+        return _op.power(
+            _op.reduce.sum(_op.power(_op.abs(data), ord), axis=dim, keepdims=keepdim),
+            reci_ord,
+        )
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4118,6 +4142,7 @@ def create_convert_map(self):
             "aten::_weight_norm": self.weight_norm,
             "aten::copy_": self.inplace_copy,
             "aten::swapaxes": self.transpose,
+            "aten::linalg_vector_norm": self.linalg_vector_norm,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 894bea60ed46..e2d5edeaa6aa 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5432,6 +5432,39 @@ def forward(self, *args):
     verify_model(Swapaxes3().float().eval(), input_data=input_data)
 
 
+def test_linalg_vector_norm():
+    """test_linalg_vector_norm"""
+    torch.set_grad_enabled(False)
+    input_shape = [3, 3]
+
+    class VectorNorm1(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x)
+
+    class VectorNorm2(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=3.5)
+
+    class VectorNorm3(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=np.inf)
+
+    class VectorNorm4(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=-np.NINF)
+
+    class VectorNorm5(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=0)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(VectorNorm1().float().eval(), input_data=input_data)
+    verify_model(VectorNorm2().float().eval(), input_data=input_data)
+    verify_model(VectorNorm3().float().eval(), input_data=input_data)
+    verify_model(VectorNorm4().float().eval(), input_data=input_data)
+    verify_model(VectorNorm5().float().eval(), input_data=input_data)
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""
 
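For reference, the converter added above follows the usual vector-norm case split: ord=0 counts non-zero elements, ord=+inf/-inf reduces to the max/min of the absolute values, and any other order computes (sum(|x|**p))**(1/p). The NumPy sketch below is illustrative only; it is not part of the patch, the function name is made up, and it is meant solely as a reference against which the Relay lowering can be sanity-checked on small inputs.

import numpy as np

def reference_vector_norm(x, order=2.0, dim=None, keepdim=False):
    # Mirrors the case split implemented by linalg_vector_norm in the patch above.
    x = np.asarray(x)
    if order == 0:
        # Count of non-zero elements, kept in the input dtype.
        return np.sum((x != 0).astype(x.dtype), axis=dim, keepdims=keepdim)
    if order == np.inf:
        return np.max(np.abs(x), axis=dim, keepdims=keepdim)
    if order == -np.inf:
        return np.min(np.abs(x), axis=dim, keepdims=keepdim)
    # General p-norm: (sum(|x| ** p)) ** (1 / p).
    return np.sum(np.abs(x) ** order, axis=dim, keepdims=keepdim) ** (1.0 / order)
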
From 880a9a412751c4160093158408e2c04b646db39f Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Wed, 15 Nov 2023 15:43:15 +0900
Subject: [PATCH 2/6] add dtype check assertion

---
 python/tvm/relay/frontend/pytorch.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index aff2786cef7f..9d85eacfd040 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3851,6 +3851,8 @@ def linalg_vector_norm(self, inputs, input_types):
         dim = inputs[2]
         keepdim = inputs[3]
 
+        assert dtype == "float32" or dtype == "float64"
+
         if ord == 0:
             return _op.reduce.sum(
                 _op.cast(_op.not_equal(data, _expr.const(0, dtype=dtype)), dtype=dtype),

From 16977c429e41aefca7be32148a8f89a35fc05c05 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Wed, 15 Nov 2023 15:43:33 +0900
Subject: [PATCH 3/6] add double-precision testcase

---
 tests/python/frontend/pytorch/test_forward.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e2d5edeaa6aa..8e5f985230d5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5464,6 +5464,13 @@ def forward(self, x):
     verify_model(VectorNorm4().float().eval(), input_data=input_data)
     verify_model(VectorNorm5().float().eval(), input_data=input_data)
 
+    input_data = torch.rand(input_shape).double()
+    verify_model(VectorNorm1().double().eval(), input_data=input_data)
+    verify_model(VectorNorm2().double().eval(), input_data=input_data)
+    verify_model(VectorNorm3().double().eval(), input_data=input_data)
+    verify_model(VectorNorm4().double().eval(), input_data=input_data)
+    verify_model(VectorNorm5().double().eval(), input_data=input_data)
+
 
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""

From d4d97865318dc09b18ce60cef89e958ee4468203 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Wed, 15 Nov 2023 15:59:24 +0900
Subject: [PATCH 4/6] Re-enable test_forward_norm and test_forward_frobenius_norm

---
 tests/python/frontend/pytorch/test_forward.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8e5f985230d5..1c062213225a 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1780,7 +1780,6 @@ def forward(self, *args):
     verify_model(LogSoftmax1().float().eval(), input_data=input_data)
 
 
-@pytest.mark.skip(reason="unsupported op aten::linalg_vector_norm")
 @tvm.testing.uses_gpu
 def test_forward_norm():
     """test_forward_norm"""
@@ -1840,7 +1839,6 @@ def forward(self, *args):
     verify_model(Norm10().float().eval(), input_data=input_data)
 
 
-@pytest.mark.skip(reason="unsupported op aten::linalg_vector_norm")
 @tvm.testing.uses_gpu
 def test_forward_frobenius_norm():
     """test_forward_frobenius_norm"""
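For context, verify_model in the tests above traces the module (or callable) with torch.jit.trace and runs the result through the Relay PyTorch frontend, which is where the new aten::linalg_vector_norm mapping and the float32/float64 assertion from patch 2 are exercised. The sketch below shows that conversion path in isolation; it is an illustration under assumptions (the wrapper class, the input name "x0", and the shape are invented for the example), not the actual verify_model implementation.

import torch
from tvm import relay

class VectorNorm(torch.nn.Module):
    """Wraps torch.linalg.vector_norm so it can be traced."""
    def forward(self, x):
        return torch.linalg.vector_norm(x, ord=3.5)

example = torch.rand(3, 3)  # float32 input, so the dtype assertion passes
traced = torch.jit.trace(VectorNorm().eval(), example)
# "x0" is an arbitrary name for the Relay input variable.
mod, params = relay.frontend.from_pytorch(traced, [("x0", [3, 3])])
print(mod)
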
From b78c510be30e9059b9adcde3975cf53fc9ab37c2 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Fri, 24 Nov 2023 11:30:42 +0900
Subject: [PATCH 5/6] cleanup test

---
 tests/python/frontend/pytorch/test_forward.py | 43 ++++++------------
 1 file changed, 14 insertions(+), 29 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1c062213225a..d7cd2532d5ec 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5433,41 +5433,26 @@ def forward(self, *args):
 def test_linalg_vector_norm():
     """test_linalg_vector_norm"""
     torch.set_grad_enabled(False)
-    input_shape = [3, 3]
-
-    class VectorNorm1(torch.nn.Module):
-        def forward(self, x):
-            return torch.linalg.vector_norm(x)
 
-    class VectorNorm2(torch.nn.Module):
-        def forward(self, x):
-            return torch.linalg.vector_norm(x, ord=3.5)
-
-    class VectorNorm3(torch.nn.Module):
-        def forward(self, x):
-            return torch.linalg.vector_norm(x, ord=np.inf)
+    def test_fn(ord):
+        return lambda x: torch.linalg.vector_norm(x, ord)
 
-    class VectorNorm4(torch.nn.Module):
-        def forward(self, x):
-            return torch.linalg.vector_norm(x, ord=-np.NINF)
-
-    class VectorNorm5(torch.nn.Module):
-        def forward(self, x):
-            return torch.linalg.vector_norm(x, ord=0)
+    input_shape = [3, 3]
 
     input_data = torch.rand(input_shape).float()
-    verify_model(VectorNorm1().float().eval(), input_data=input_data)
-    verify_model(VectorNorm2().float().eval(), input_data=input_data)
-    verify_model(VectorNorm3().float().eval(), input_data=input_data)
-    verify_model(VectorNorm4().float().eval(), input_data=input_data)
-    verify_model(VectorNorm5().float().eval(), input_data=input_data)
+    verify_model(test_fn(ord=2.0), input_data=input_data)
+    verify_model(test_fn(ord=3.5), input_data=input_data)
+    verify_model(test_fn(ord=np.inf), input_data=input_data)
+    verify_model(test_fn(ord=np.NINF), input_data=input_data)
+    verify_model(test_fn(ord=0), input_data=input_data)
 
+    # Also test on double
     input_data = torch.rand(input_shape).double()
-    verify_model(VectorNorm1().double().eval(), input_data=input_data)
-    verify_model(VectorNorm2().double().eval(), input_data=input_data)
-    verify_model(VectorNorm3().double().eval(), input_data=input_data)
-    verify_model(VectorNorm4().double().eval(), input_data=input_data)
-    verify_model(VectorNorm5().double().eval(), input_data=input_data)
+    verify_model(test_fn(ord=2.0), input_data=input_data)
+    verify_model(test_fn(ord=3.5), input_data=input_data)
+    verify_model(test_fn(ord=np.inf), input_data=input_data)
+    verify_model(test_fn(ord=np.NINF), input_data=input_data)
+    verify_model(test_fn(ord=0), input_data=input_data)
 
 
 class TestSetSpan:
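The cleanup above replaces five nn.Module subclasses with a single test_fn factory; verify_model traces whatever callable it is given, so the float and double runs stay identical apart from the input dtype. As a quick numeric illustration of the orders the tests cover (values follow from the standard vector-norm definitions; the snippet is not part of the patch):

import torch

x = torch.tensor([3.0, -4.0])
print(torch.linalg.vector_norm(x, ord=2))              # 5.0 -> sqrt(3**2 + 4**2)
print(torch.linalg.vector_norm(x, ord=float("inf")))   # 4.0 -> max(|x|)
print(torch.linalg.vector_norm(x, ord=float("-inf")))  # 3.0 -> min(|x|)
print(torch.linalg.vector_norm(x, ord=0))              # 2.0 -> number of non-zero elements
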
From d7223227c7407258fae4323c364cb23abee3c245 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Fri, 24 Nov 2023 20:23:43 +0900
Subject: [PATCH 6/6] rename `ord`->`order` to avoid W0622(redefined-builtin)

---
 tests/python/frontend/pytorch/test_forward.py | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d7cd2532d5ec..424e30bc2214 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5434,25 +5434,25 @@ def test_linalg_vector_norm():
     """test_linalg_vector_norm"""
     torch.set_grad_enabled(False)
 
-    def test_fn(ord):
-        return lambda x: torch.linalg.vector_norm(x, ord)
+    def test_fn(order):
+        return lambda x: torch.linalg.vector_norm(x, ord=order)
 
     input_shape = [3, 3]
 
     input_data = torch.rand(input_shape).float()
-    verify_model(test_fn(ord=2.0), input_data=input_data)
-    verify_model(test_fn(ord=3.5), input_data=input_data)
-    verify_model(test_fn(ord=np.inf), input_data=input_data)
-    verify_model(test_fn(ord=np.NINF), input_data=input_data)
-    verify_model(test_fn(ord=0), input_data=input_data)
+    verify_model(test_fn(order=2), input_data=input_data)
+    verify_model(test_fn(order=3.5), input_data=input_data)
+    verify_model(test_fn(order=np.inf), input_data=input_data)
+    verify_model(test_fn(order=np.NINF), input_data=input_data)
+    verify_model(test_fn(order=0), input_data=input_data)
 
     # Also test on double
     input_data = torch.rand(input_shape).double()
-    verify_model(test_fn(ord=2.0), input_data=input_data)
-    verify_model(test_fn(ord=3.5), input_data=input_data)
-    verify_model(test_fn(ord=np.inf), input_data=input_data)
-    verify_model(test_fn(ord=np.NINF), input_data=input_data)
-    verify_model(test_fn(ord=0), input_data=input_data)
+    verify_model(test_fn(order=2), input_data=input_data)
+    verify_model(test_fn(order=3.5), input_data=input_data)
+    verify_model(test_fn(order=np.inf), input_data=input_data)
+    verify_model(test_fn(order=np.NINF), input_data=input_data)
+    verify_model(test_fn(order=0), input_data=input_data)
 
 
 class TestSetSpan:
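The final rename is purely a lint fix: ord is a Python builtin, and pylint reports W0622 (redefined-builtin) when a local name shadows it. A minimal illustration, unrelated to the patch itself:

def shadowed(ord):  # pylint W0622: the parameter hides the builtin ord()
    return lambda x: x ** ord

def clean(order):  # no warning; the builtin ord() stays visible inside the function
    return lambda x: x ** order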