From 4520515996dba25662644a3160ce6ca661d9e6d4 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Wed, 11 Aug 2021 20:12:10 +0800 Subject: [PATCH 1/4] add test case and fix bug --- python/tvm/relay/frontend/paddlepaddle.py | 94 +++- .../frontend/paddlepaddle/test_forward.py | 466 ++++++++++++++++-- 2 files changed, 497 insertions(+), 63 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 3d63826b6ea2..02f256ab3435 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -19,6 +19,7 @@ """Paddle: PArallel Distributed Deep LEarning.""" import copy import warnings +from pandas.core.dtypes.inference import is_scalar import six import numpy as np @@ -68,10 +69,13 @@ def convert_arg_max(g, op, block): axis = op.attr('axis') keepdims = op.attr('keepdims') flatten = op.attr('flatten') - assert not flatten, "Only flatten==True is supported for PaddlePaddle's arg_max" - x = g.get_node(x.input('X')[0]) - out = _op.argmax(x, axis=axis, keepdims=keepdims) + x = g.get_node(op.input('X')[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + out = _op.argmax(x, axis=None, keepdims=True) + else: + out = _op.argmax(x, axis=axis, keepdims=keepdims) g.add_node(op.output('Out')[0], out) @@ -166,9 +170,9 @@ def convert_cumsum(g, op, block): flatten = op.attr('flatten') reverse = op.attr('reverse') - assert not flatten, "Only flatten==False is supported for PaddlePaddle's cumsum" - x = g.get_node(op.input('X')[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) if reverse: x = _op.reverse(x, axis=axis) out = _op.cumsum(x, axis=axis, exclusive=exclusive) @@ -281,6 +285,12 @@ def convert_fill_constant(g, op, block): shape = block.var(op.output('Out')[0]).shape dtype = block.var(op.output('Out')[0]).dtype dtype = str(dtype).strip().split('.')[1] + if op.input('ValueTensor'): + shape = g.get_node(op.input('ValueTensor')[0]) + shape = infer_value(shape, g.get_params()).numpy() + if op.input('ShapeTensor'): + shape = g.get_node(op.input('ShapeTensor')[0]) + shape = infer_value(shape, g.get_params()).numpy() value = np.full(shape, value, dtype) out = _expr.const(value.astype(dtype)).astype(dtype) g.add_node(op.output('Out')[0], out) @@ -333,8 +343,22 @@ def convert_layer_norm(g, op, block): begin_norm_axis = op.attr('begin_norm_axis') epsilon = op.attr('epsilon') x = g.get_node(op.input('X')[0]) - bias = g.get_node(op.input('Bias')[0]) - scale = g.get_node(op.input('Scale')[0]) + bias_input = op.input('Bias') + scale_input = op.input('Scale') + + x_shape = infer_shape(x) + assert begin_norm_axis == -1 or begin_norm_axis == len(x_shape) - 1, "Support only normalization over last one dimension." + + if bias_input: + bias = g.get_node(bias_input[0]) + else: + bias = _expr.const(np.zeros(x_shape[begin_norm_axis])) + + if scale_input: + scale = g.get_node(scale_input[0]) + else: + scale = _expr.const(np.ones(x_shape[begin_norm_axis])) + out = _op.nn.layer_norm(x, gamma=scale, beta=bias, @@ -351,7 +375,7 @@ def convert_leaky_relu(g, op, block): alpha = op.attr('alpha') x = g.get_node(op.input('X')[0]) out = _op.nn.leaky_relu(x, alpha=alpha) - g.add_node(op.output('Out')[0]) + g.add_node(op.output('Out')[0], out) def convert_lookup_table(g, op, block): @@ -540,7 +564,7 @@ def convert_pool2d(g, op, block): if global_pooling: adaptive = True ksize = [1, 1] - + input = g.get_node(op.input('X')[0]) in_h, in_w = infer_shape(input)[2:] @@ -587,8 +611,27 @@ def convert_pool2d(g, op, block): def convert_reshape(g, op, block): """Operator converter for reshape.""" - shape = op.attr('shape') - out = _op.reshape(g.get_node(op.input('X')[0]), shape) + shape_attr = op.input('Shape') + tensor_attr = op.input('ShapeTensor') + data = g.get_node(op.input('X')[0]) + if shape_attr: + new_shape = g.get_node(shape_attr[0]) + elif tensor_attr: + tmp_shape = [] + for shape_name in tensor_attr: + shape = g.get_node(shape_name) + if len(infer_shape(shape)) == 0: + shape = _op.reshape(shape, [-1]) + if isinstance(shape, _expr.Constant): + tmp_shape.append(shape) + elif isinstance(shape, _expr.Expr): + tmp_shape.append(shape) + else: + tmp_shape.append(_expr.const(np.array(shape).astype('int64'))) + new_shape = _op.concatenate(tmp_shape, axis=0) + else: + new_shape = op.attr('shape') + out = _op.reshape(data, new_shape) g.add_node(op.output('Out')[0], out) @@ -627,7 +670,7 @@ def convert_shape(g, op, block): def convert_slice(g, op, block): """Operator converter for slice.""" - def parameter_process(starts, ends, axes): + def parameter_process(starts, ends, axes, dshape): new_axes = [] new_starts = [] new_ends = [] @@ -640,22 +683,29 @@ def parameter_process(starts, ends, axes): pop_index += 1 else: new_starts.append(0) - new_ends.append(np.iinfo(np.int32).max) + new_ends.append(dshape[i]) return new_starts, new_ends, new_axes + data = g.get_node(op.input('Input')[0]) + dshape = infer_shape(data) starts = op.attr('starts') ends = op.attr('ends') axes = op.attr('axes') + decrease_axis = op.attr('decrease_axis') if isinstance(starts, int): starts = [starts] if isinstance(ends, int): ends = [ends] if isinstance(axes, int): axes = [axes] - starts, ends, axes = parameter_process(starts, ends, axes) - out = _op.strided_slice(g.get_node(op.input('Input')[0]), + if isinstance(decrease_axis, int): + decrease_axis = [decrease_axis] + starts, ends, axes = parameter_process(starts, ends, axes, dshape) + out = _op.strided_slice(data, begin=starts, end=ends) + if decrease_axis: + out = _op.squeeze(out, axis=decrease_axis) g.add_node(op.output('Out')[0], out) @@ -672,15 +722,6 @@ def convert_softmax(g, op, block): out = e / _op.sum(e, axis, keepdims=True) g.add_node(op.output('Out')[0], out) - -def convert_transpose(g, op, block): - """Operator converter for transpose.""" - - perm = op.attr('axis') - out = _op.transpose(g.get_node(op.input('X')[0]), axes=perm) - g.add_node(op.output('Out')[0], out) - - def convert_unsqueeze(g, op, block): """Operator converter for unsqueeze.""" @@ -727,7 +768,6 @@ def convert_unsqueeze(g, op, block): 'slice': convert_slice, 'softmax': convert_softmax, 'tanh': convert_activation, - 'transpose2': convert_transpose, 'unsqueeze2': convert_unsqueeze, } @@ -747,7 +787,9 @@ def get_node(self, name): def add_node(self, name, node): self.nodes[name] = fold_constant(node) - def get_params(self, name): + def get_params(self, name=None): + if name is None: + return self.params assert name in self.params return self.params[name] diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 94a2c468b21a..d4dea6cbe59a 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -19,7 +19,6 @@ import shutil import numpy as np -from paddle.fluid.layers.nn import pad import tvm import tvm.testing import tvm.topi.testing @@ -48,38 +47,45 @@ def get_paddle_model(func, input_spec): return baseline_model def verify_model(func, input_data, rtol=1e-5, atol=1e-5): - if not (isinstance(input_data, list) or isinstance(input_data, tuple)): + if not (isinstance(input_data, (tuple, list))): input_data = [input_data] input_spec = [] input_names = [] input_shape_dict = {} + compiled_input = {} for idx, data in enumerate(input_data): input_name = "input{}".format(idx) input_spec.append(paddle.static.InputSpec(dtype=data.dtype, shape=data.shape, name=input_name)) input_names.append(input_name) input_shape_dict[input_name] = data.shape + if isinstance(data, np.ndarray): + compiled_input[input_name] = data + else: + compiled_input[input_name] = data.numpy() baseline_model = get_paddle_model(func, input_spec) - baseline_outputs = baseline_model(*[input.clone() for input in input_data]) + baseline_outputs = baseline_model(*[input[:] for input in input_data]) # get paddle outputs - if isinstance(baseline_outputs, tuple): + if isinstance(baseline_outputs, (tuple, list)): baseline_outputs = tuple(out.numpy() for out in baseline_outputs) else: baseline_outputs = (baseline_outputs.numpy(),) mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) - for arg in mod["main"].params[: len(input_names)]: + parms_num = min(len(input_names), len(mod["main"].params)) + compiled_names = [] + for arg in mod["main"].params[: parms_num]: assert arg.name_hint in input_names - compiled_input = dict(zip(input_names, [inp.clone().numpy() for inp in input_data])) + compiled_names.append(arg.name_hint) with tvm.transform.PassContext(opt_level=3): for target, dev in tvm.testing.enabled_targets(): lib = relay.build(mod, target=target, params=params) gmod = graph_executor.GraphModule(lib["default"](dev)) - for name, inp in compiled_input.items(): - gmod.set_input(name, inp) + for name in compiled_names: + gmod.set_input(name, compiled_input[name]) gmod.run() for i, baseline_output in enumerate(baseline_outputs): @@ -112,26 +118,310 @@ def add_subtract3(inputs1, inputs2): verify_model(add_subtract3, [input_data, input_data2]) @tvm.testing.uses_gpu -def test_forward_multiply(): - input_shape = [10] +def test_forward_argmax(): + input_shape = [1, 3, 10, 10] + + class ArgMax(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argmax(inputs) + + class ArgMax1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1) + + class ArgMax2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1, keepdim=False) + + class ArgMax3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=2, keepdim=True) + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax(), input_data=input_data) + verify_model(ArgMax1(), input_data=input_data) + verify_model(ArgMax2(), input_data=input_data) + verify_model(ArgMax3(), input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_assign(): + @paddle.jit.to_static + def assign(inputs): + return paddle.assign(inputs) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(assign, [input_data,]) + input_data2 = np.random.randint(100, size=input_shape) + verify_model(assign, [input_data2,]) + +@tvm.testing.uses_gpu +def test_forward_batch_norm(): + class BatchNorm1D(nn.Layer): + def __init__(self): + super(BatchNorm1D, self).__init__() + self.batch_norm = nn.BatchNorm1D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + class BatchNorm2D(nn.Layer): + def __init__(self): + super(BatchNorm2D, self).__init__() + self.batch_norm = nn.BatchNorm2D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + class BatchNorm3D(nn.Layer): + def __init__(self): + super(BatchNorm3D, self).__init__() + self.batch_norm = nn.BatchNorm3D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + input_data = paddle.rand((2, 2, 3), dtype="float32") + verify_model(BatchNorm1D(), input_data=input_data) + input_data = paddle.rand((2, 2, 2, 3), dtype="float32") + verify_model(BatchNorm2D(), input_data=input_data) + input_data = paddle.rand((2, 2, 2, 2, 3), dtype="float32") + verify_model(BatchNorm3D(), input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_cast(): + @paddle.jit.to_static + def cast1(inputs, dtype="uint8"): + return paddle.cast(inputs, dtype) + + @paddle.jit.to_static + def cast2(inputs, dtype="int64"): + return inputs.cast(dtype) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") * 100 + verify_model(cast1, [input_data,]) + verify_model(cast2, [input_data,]) + +@tvm.testing.uses_gpu +def test_forward_concat_unsqueeze(): + @paddle.jit.to_static + def concat_unsqueeze1(inputs): + return paddle.concat([inputs[:, 0].unsqueeze(1), inputs[:, 1].unsqueeze(1)], axis=1) + + @paddle.jit.to_static + def concat_unsqueeze2(inputs): + a = (inputs[:, :, 0] + 2) * 7 + b = (inputs[:, :, 1] + 3) * 11 + c = (inputs[:, :, 2] + 5) * 13 + return paddle.concat([paddle.unsqueeze(t, axis=2) for t in [a, b, c]], axis=2) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(concat_unsqueeze1, input_data=input_data) + verify_model(concat_unsqueeze2, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_cumsum(): + @paddle.jit.to_static + def cusum1(inputs): + return paddle.cumsum(inputs) + + @paddle.jit.to_static + def cusum2(inputs): + return paddle.cumsum(inputs, axis=0) + + @paddle.jit.to_static + def cusum3(inputs): + return paddle.cumsum(inputs, axis=1) + + input_data = paddle.randint(0, 100, (10, 10), dtype=paddle.int32) + verify_model(cusum1, [input_data]) + verify_model(cusum1, [input_data.astype(paddle.int64)]) + verify_model(cusum2, [input_data, ]) + verify_model(cusum3, [input_data, ]) + +@tvm.testing.uses_gpu +def test_forward_conv(): + conv2d_input_shape = [1, 3, 10, 10] + + class Conv2D1(nn.Layer): + def __init__(self): + super(Conv2D1, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.softmax = nn.Softmax() + + @paddle.jit.to_static + def forward(self, inputs): + return self.softmax(self.conv(inputs)) + + class Conv2D2(nn.Layer): + def __init__(self): + super(Conv2D2, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) + self.softmax = nn.Softmax() + + @paddle.jit.to_static + def forward(self, inputs): + return self.softmax(self.conv(inputs)) + + conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=conv2d_input_data) + verify_model(Conv2D2(), input_data=conv2d_input_data) + +@tvm.testing.uses_gpu +def test_forward_dropout(): + @paddle.jit.to_static + def dropout(inputs): + return nn.functional.dropout(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(dropout, input_data=input_data[0, 0]) + verify_model(dropout, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_shape_full(): + @paddle.jit.to_static + def full1(inputs): + return paddle.full(paddle.shape(inputs), 3.14) + + @paddle.jit.to_static + def full2(inputs): + return paddle.full(paddle.shape(inputs), 1.0, dtype=inputs.dtype) + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(full1, input_data=[input_data]) + verify_model(full2, input_data=[input_data]) + +@tvm.testing.uses_gpu +def test_forward_ones_like(): + @paddle.jit.to_static + def ones_like1(inputs): + return paddle.ones_like(inputs) + + @paddle.jit.to_static + def ones_like2(inputs): + return paddle.ones_like(inputs, dtype="int32") + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ones_like1, input_data=input_data) + verify_model(ones_like2, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_gelu(): + @paddle.jit.to_static + def gelu(inputs): + return nn.functional.gelu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(gelu, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_hard_sigmoid(): + @paddle.jit.to_static + def hard_sigmoid(inputs): + return nn.functional.hardsigmoid(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_sigmoid, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_hard_swish(): + @paddle.jit.to_static + def hard_swish(inputs): + return nn.functional.hardswish(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_swish, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_layer_norm(): + @paddle.jit.to_static + def layer_norm(inputs, weight, bias): + return nn.functional.layer_norm(inputs, inputs.shape[-1], weight=weight, bias=bias) + + class LayerNorm(nn.Layer): + def __init__(self): + super(LayerNorm, self).__init__() + data_shape = [10] + self.layer_norm = nn.LayerNorm(data_shape) + + @paddle.jit.to_static + def forward(self, inputs): + return self.layer_norm(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + weight = paddle.rand([10], dtype="float32") + bias = paddle.rand([10], dtype="float32") + verify_model(layer_norm, input_data=[input_data, weight, bias]) + verify_model(LayerNorm(), input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_leaky_relu(): + @paddle.jit.to_static + def leaky_relu(inputs): + return nn.functional.leaky_relu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(leaky_relu, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_look_up(): + @paddle.jit.to_static + def look_up(inputs, weight): + return nn.functional.embedding(inputs, weight) + + class LookUp(nn.Layer): + def __init__(self): + super(LookUp, self).__init__() + self.embedding = paddle.nn.Embedding(10, 4, sparse=True) + + @paddle.jit.to_static + def forward(self, inputs): + return self.embedding(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.randint(0, 10, input_shape, dtype="int32") + weight = paddle.rand([10, 4], dtype="float32") + verify_model(look_up, input_data=[input_data, weight]) + verify_model(LookUp(), input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_multiply(): @paddle.jit.to_static def multiply1(inputs): return inputs * inputs @paddle.jit.to_static def multiply2(inputs): - return inputs * 1.0 + return inputs * 1.0 / 2.0 @paddle.jit.to_static - def multiply3(inputs): + def multiply3(inputs, inputs2): ones = paddle.ones([10], dtype="float32") - return inputs * ones + return inputs * ones / inputs2 + input_shape = [10] input_data = paddle.rand(input_shape, dtype="float32") verify_model(multiply1, input_data=input_data) verify_model(multiply2, input_data=input_data) - verify_model(multiply3, input_data=input_data) + input_data2 = paddle.rand(input_shape, dtype="float32") + verify_model(multiply3, input_data=[input_data, input_data2]) @tvm.testing.uses_gpu def test_forward_matmul(): @@ -161,36 +451,138 @@ def forward(self, input1, input2): verify_model(MatMul1(), input_data=[input_data1, input_data2]) @tvm.testing.uses_gpu -def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] +def test_forward_pool2d(): + @paddle.jit.to_static + def pool2d1(inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) - class Conv2D1(nn.Layer): - def __init__(self): - super(Conv2D1, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) - self.softmax = nn.Softmax() - - @paddle.jit.to_static - def forward(self, inputs): - return self.softmax(self.conv(inputs)) + @paddle.jit.to_static + def pool2d2(inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) - class Conv2D2(nn.Layer): - def __init__(self): - super(Conv2D2, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) - self.softmax = nn.Softmax() + @paddle.jit.to_static + def pool2d3(inputs): + return nn.functional.max_pool2d(inputs, kernel_size=2, stride=2, padding=0, return_mask=True) - @paddle.jit.to_static - def forward(self, inputs): - return self.softmax(self.conv(inputs)) + input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype='float32', min=-1, max=1) + verify_model(pool2d1, input_data=input_data) + verify_model(pool2d2, input_data=input_data) + # verify_model(pool2d3, input_data=input_data) - conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") - verify_model(Conv2D1(), input_data=conv2d_input_data) - verify_model(Conv2D2(), input_data=conv2d_input_data) +@tvm.testing.uses_gpu +def test_forward_relu(): + @paddle.jit.to_static + def relu(inputs): + return nn.functional.relu(inputs) + input_shape = [10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(relu, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_reshape(): + @paddle.jit.to_static + def reshape1(inputs, x): + new_shape = paddle.shape(x) + return paddle.reshape(inputs, new_shape) + + @paddle.jit.to_static + def reshape2(inputs): + return inputs.reshape([-1]) + + @paddle.jit.to_static + def reshape3(inputs): + data_shape = inputs.shape + return inputs.reshape([data_shape[0] * data_shape[1], data_shape[2]]) + + @paddle.jit.to_static + def reshape4(inputs, x): + new_shape = paddle.shape(x) + return paddle.reshape(inputs, [new_shape[2], 2, -1]) + + input_shape = [2, 1, 10, 1, 10] + input_data = paddle.rand(input_shape, dtype="float32") + input_data2 = paddle.randn([2, 1, 10, 10]) + verify_model(reshape1, input_data=[input_data, input_data2]) + verify_model(reshape2, input_data=input_data) + verify_model(reshape3, input_data=paddle.randn((2, 3, 4))) + verify_model(reshape4, input_data=[input_data, input_data2]) + +@tvm.testing.uses_gpu +def test_forward_scale(): + @paddle.jit.to_static + def scale1(inputs): + return paddle.scale(inputs, scale=2.0, bias=1.0) + + @paddle.jit.to_static + def scale2(inputs): + return paddle.scale(inputs, scale=3, bias=2.1, act="gelu") + + input_data = paddle.randn(shape=[2,3], dtype='float32') + verify_model(scale1, input_data=[input_data,]) + verify_model(scale2, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_slice(): + @paddle.jit.to_static + def slice1(inputs): + return inputs[:, :, :, :3] + + @paddle.jit.to_static + def slice2(inputs): + return inputs[0, :, :-3, :] + + @paddle.jit.to_static + def slice3(inputs): + return inputs[0::2, 0::2] + inputs[1::2, 1::2] + + @paddle.jit.to_static + def slice4(inputs): + x0 = paddle.to_tensor([2]) - paddle.to_tensor([1]) + x1 = paddle.to_tensor([3]) + paddle.to_tensor([1]) + return inputs[:, x0:, 1:x1, :] + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(slice1, input_data=[input_data,]) + verify_model(slice2, input_data=input_data) + # need op "strided_slice" + # verify_model(slice3, input_data=paddle.randn((4, 4))) + # need op "assign_value" + # verify_model(slice4, input_data=input_data) + +@tvm.testing.uses_gpu +def test_forward_tanh(): + @paddle.jit.to_static + def tanh(inputs): + return paddle.tanh(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(tanh, input_data=input_data) if __name__ == "__main__": test_forward_add_subtract() + test_forward_argmax() + test_forward_assign() + test_forward_batch_norm() + test_forward_cast() + test_forward_concat_unsqueeze() + test_forward_cumsum() + test_forward_conv() + test_forward_dropout() + test_forward_shape_full() + test_forward_ones_like() + test_forward_gelu() + test_forward_hard_sigmoid() + test_forward_hard_swish() + test_forward_layer_norm() + test_forward_leaky_relu() + test_forward_look_up() test_forward_multiply() test_forward_matmul() - test_forward_conv() + test_forward_pool2d() + test_forward_relu() + test_forward_reshape() + test_forward_scale() + test_forward_slice() + test_forward_tanh() From 1d599c1d2faf6284da38fdbc38fb2b29c419cb2a Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Wed, 11 Aug 2021 20:18:04 +0800 Subject: [PATCH 2/4] delete import pandas --- python/tvm/relay/frontend/paddlepaddle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 02f256ab3435..5308c960b838 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -19,7 +19,6 @@ """Paddle: PArallel Distributed Deep LEarning.""" import copy import warnings -from pandas.core.dtypes.inference import is_scalar import six import numpy as np From 75ef19b0f143e3986fcf13d7fae5286d94eec5da Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Fri, 13 Aug 2021 14:45:34 +0800 Subject: [PATCH 3/4] add paddlepaddle tests --- tests/scripts/task_python_frontend.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index fb388a6b7edd..d25d52438daa 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -54,3 +54,6 @@ run_pytest cython python-frontend-darknet tests/python/frontend/darknet echo "Running relay PyTorch frontend test..." run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch + +echo "Running relay PaddlePaddle frontend test..." +run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle From 67ca20da4b4c921f7a1986733107b429c2e869eb Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Fri, 13 Aug 2021 15:15:22 +0800 Subject: [PATCH 4/4] modify the variable name of convert_reshape --- python/tvm/relay/frontend/paddlepaddle.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 5308c960b838..dc2e94c896b8 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -610,14 +610,14 @@ def convert_pool2d(g, op, block): def convert_reshape(g, op, block): """Operator converter for reshape.""" - shape_attr = op.input('Shape') - tensor_attr = op.input('ShapeTensor') + input_shape = op.input('Shape') + input_shape_tensor = op.input('ShapeTensor') data = g.get_node(op.input('X')[0]) - if shape_attr: - new_shape = g.get_node(shape_attr[0]) - elif tensor_attr: + if input_shape: + new_shape = g.get_node(input_shape[0]) + elif input_shape_tensor: tmp_shape = [] - for shape_name in tensor_attr: + for shape_name in input_shape_tensor: shape = g.get_node(shape_name) if len(infer_shape(shape)) == 0: shape = _op.reshape(shape, [-1])