Skip to content

Commit

Permalink
Merge pull request #60 from jiangjiajun/add-more-cases
Browse files Browse the repository at this point in the history
add more cases for tests
  • Loading branch information
jiangjiajun authored Oct 4, 2021
2 parents 2a5e30d + e600036 commit 4887e96
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 77 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@


def _get_pad_size(in_size, dilated_kernel_size, stride_size):
"""Calculate the paddings size."""
"""Calculate the paddings size for Conv/Pool in SAME padding mode."""

if stride_size == 1 or in_size % stride_size == 0:
pad = max(dilated_kernel_size - stride_size, 0)
Expand Down
205 changes: 129 additions & 76 deletions tests/python/frontend/paddlepaddle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tvm.topi.testing
from tvm import relay
from tvm.contrib import graph_executor
import pytest

import paddle
import paddle.nn as nn
Expand Down Expand Up @@ -127,8 +128,6 @@ def add_subtract3(inputs1, inputs2):

@tvm.testing.uses_gpu
def test_forward_arg_max_min():
input_shape = [1, 3, 10, 10]

class ArgMax(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
Expand Down Expand Up @@ -169,32 +168,50 @@ class ArgMin3(nn.Layer):
def forward(self, inputs):
return inputs.argmin(axis=2, keepdim=True)

input_data = paddle.rand(input_shape, dtype="float32")
verify_model(ArgMax(), input_data=input_data)
verify_model(ArgMax1(), input_data=input_data)
verify_model(ArgMax2(), input_data=input_data)
verify_model(ArgMax3(), input_data=input_data)
verify_model(ArgMin(), input_data=input_data)
verify_model(ArgMin1(), input_data=input_data)
verify_model(ArgMin2(), input_data=input_data)
verify_model(ArgMin3(), input_data=input_data)
input_shapes = [[256], [10, 128], [100, 500, 200], [1, 3, 224, 224]]
for input_shape in input_shapes:
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(ArgMax(), input_data=input_data)
verify_model(ArgMin(), input_data=input_data)
for input_shape in input_shapes[1:]:
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(ArgMax1(), input_data=input_data)
verify_model(ArgMax2(), input_data=input_data)
verify_model(ArgMin1(), input_data=input_data)
verify_model(ArgMin2(), input_data=input_data)
for input_shape in input_shapes[2:]:
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(ArgMax3(), input_data=input_data)
verify_model(ArgMin3(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_argsort():
@paddle.jit.to_static
def argsort(inputs):
return paddle.argsort(inputs)
class ArgSort1(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.argsort(inputs)

@paddle.jit.to_static
def argsort2(inputs):
return paddle.argsort(inputs, axis=0, descending=True)
class ArgSort2(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.argsort(inputs, axis=0, descending=True)

input_shape = [2, 3, 5]
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(argsort, input_data)
input_data2 = np.random.randint(100, size=input_shape)
verify_model(argsort2, input_data2)
class ArgSort3(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.argsort(inputs, axis=-1, descending=True)

input_shapes = [[256], [10, 20], [10, 10, 3], [1, 3, 5, 5]]
for input_shape in input_shapes:
# Avoid duplicate elements in the array which will bring
# different results with different sort algorithms
np.random.seed(13)
np_data = np.random.choice(range(-5000, 5000), np.prod(input_shape), replace=False)
input_data = paddle.to_tensor(np_data.reshape(input_shape).astype("int64"))
verify_model(ArgSort1(), [input_data])
verify_model(ArgSort2(), [input_data])
verify_model(ArgSort3(), [input_data])


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -291,23 +308,27 @@ def cast2(inputs, dtype="int64"):

@tvm.testing.uses_gpu
def test_forward_check_tensor():
@paddle.jit.to_static
def isfinite(inputs):
return paddle.cast(paddle.isfinite(inputs), "int32")
class IsFinite(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.cast(paddle.isfinite(inputs), "int32")

@paddle.jit.to_static
def isnan(inputs):
return paddle.cast(paddle.isnan(inputs), "int32")
class IsNan(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.cast(paddle.isnan(inputs), "int32")

@paddle.jit.to_static
def isinf(inputs):
return paddle.cast(paddle.isinf(inputs), "int32")
class IsInf(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.cast(paddle.isinf(inputs), "int32")

input_shape = [5, 5]
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(isfinite, input_data=input_data)
verify_model(isnan, input_data=input_data)
verify_model(isinf, input_data=input_data)
input_shapes = [[32], [8, 128], [2, 128, 256], [2, 3, 224, 224], [2, 2, 3, 229, 229]]
for input_shape in input_shapes:
input_data = paddle.rand(input_shape, dtype="float32")
verify_model(IsFinite(), input_data=input_data)
verify_model(IsNan(), input_data=input_data)
verify_model(IsInf(), input_data=input_data)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -391,15 +412,16 @@ def forward(self, inputs):

@tvm.testing.uses_gpu
def test_forward_dot():
@paddle.jit.to_static
def dot(x, y):
return paddle.dot(x, y)
class Dot(nn.Layer):
@paddle.jit.to_static
def forward(self, x, y):
return paddle.dot(x, y)

x_shape = [10, 3]
y_shape = [10, 3]
x_data = paddle.rand(x_shape, dtype="float32")
y_data = paddle.rand(y_shape, dtype="float32")
verify_model(dot, input_data=[x_data, y_data])
input_shapes = [[128], [8, 128]]
for input_shape in input_shapes:
x_data = paddle.rand(input_shape, dtype="float32")
y_data = paddle.rand(input_shape, dtype="float32")
verify_model(Dot(), input_data=[x_data, y_data])


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -435,44 +457,70 @@ def forward(self, input1, input2):
api_list = [
"equal",
]
input_shape = [10, 10]
input_shape_2 = [
10,
]
x_data = paddle.randint(1, 10, input_shape, dtype="int32")
y_data = paddle.randint(1, 10, input_shape_2, dtype="int32")
for api_name in api_list:
verify_model(ElemwiseAPI(api_name), [x_data, y_data])
x_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]]
y_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]]
for x_shape, y_shape in zip(x_shapes, y_shapes):
x_data = paddle.randint(1, 1000, x_shape, dtype="int32")
y_data = paddle.randint(1, 1000, y_shape, dtype="int32")
for api_name in api_list:
verify_model(ElemwiseAPI(api_name), [x_data, y_data])


@tvm.testing.uses_gpu
def test_forward_expand():
@paddle.jit.to_static
def expand1(inputs):
return paddle.expand(inputs, shape=[2, 3])
return paddle.expand(inputs, shape=[2, 128])

@paddle.jit.to_static
def expand2(inputs):
shape = paddle.to_tensor(np.array([2, 3]).astype("int32"))
return paddle.expand(inputs, shape=[3, 1, 8, 256])

@paddle.jit.to_static
def expand3(inputs):
return paddle.expand(inputs, shape=[5, 1, 3, 224, 224])

@paddle.jit.to_static
def expand4(inputs):
shape = paddle.to_tensor(np.array([2, 128]).astype("int32"))
return paddle.expand(inputs, shape=shape)

x_shape = [3]
x_data = paddle.rand(x_shape, dtype="float32")
verify_model(expand1, input_data=[x_data])
verify_model(expand2, input_data=[x_data])
@paddle.jit.to_static
def expand5(inputs):
shape = paddle.to_tensor(np.array([3, 1, 8, 256]).astype("int32"))
return paddle.expand(inputs, shape=shape)

@paddle.jit.to_static
def expand6(inputs):
shape = paddle.to_tensor(np.array([5, 1, 3, 224, 224]).astype("int32"))
return paddle.expand(inputs, shape=shape)

data = paddle.rand([128], dtype="float32")
verify_model(expand1, input_data=[data])
verify_model(expand4, input_data=[data])
data = paddle.rand([8, 256], dtype="float32")
verify_model(expand2, input_data=[data])
verify_model(expand5, input_data=[data])
data = paddle.rand([1, 3, 224, 224], dtype="float32")
verify_model(expand3, input_data=[data])
verify_model(expand6, input_data=[data])


@tvm.testing.uses_gpu
def test_forward_expand_as():
@paddle.jit.to_static
def expand_as(x, y):
z = paddle.expand_as(x, y)
z += y
return z
class ExpandAs(nn.Layer):
@paddle.jit.to_static
def forward(self, x, y):
z = paddle.expand_as(x, y)
z += y
return z

data_x = paddle.to_tensor([1, 2, 3], dtype="int32")
data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32")
verify_model(expand_as, [data_x, data_y])
x_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]]
y_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]]
for x_shape, y_shape in zip(x_shapes, y_shapes):
x_data = paddle.rand(x_shape, dtype="float32")
y_data = paddle.rand(y_shape, dtype="float32")
verify_model(ExpandAs(), [x_data, y_data])


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -591,11 +639,14 @@ def forward(self, x, y):
z = self.func(x, y, out=out)
return paddle.cast(z, "int32")

x = paddle.to_tensor([True])
y = paddle.to_tensor([True, False, True, False])
verify_model(LogicalAPI("logical_and"), [x, y])
verify_model(LogicalAPI("logical_or"), [x, y])
verify_model(LogicalAPI("logical_xor"), [x, y])
x_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]]
y_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]]
for x_shape, y_shape in zip(x_shapes, y_shapes):
x_data = paddle.randint(0, 2, x_shape).astype("bool")
y_data = paddle.randint(0, 2, y_shape).astype("bool")
verify_model(LogicalAPI("logical_and"), [x_data, y_data])
verify_model(LogicalAPI("logical_or"), [x_data, y_data])
verify_model(LogicalAPI("logical_xor"), [x_data, y_data])


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -796,11 +847,13 @@ def forward(self, inputs):
"relu",
"tanh",
]
input_shape = [1, 3, 10, 10]
input_data = paddle.rand(input_shape, dtype="float32")
for api_name in api_list:
verify_model(MathAPI(api_name), input_data=input_data)
input_shapes = [[128], [2, 256], [1000, 128, 32], [7, 3, 256, 256]]
for input_shape in input_shapes:
input_data = paddle.rand(input_shape, dtype="float32")
for api_name in api_list:
verify_model(MathAPI(api_name), input_data=input_data)


if __name__ == "__main__":
pytest.main([__file__])
# pytest.main([__file__])
test_forward_math_api()

0 comments on commit 4887e96

Please sign in to comment.