diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 62cfbc6c3b94..668c5d7b8261 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -25,7 +25,6 @@ from .. import analysis from .. import ty as _ty -from ... import nd as _nd from .. import expr as _expr from .. import function as _function from .. import ty as _ty @@ -370,7 +369,7 @@ def convert_expand(g, op, block): sizes = op.attr("shape") if isinstance(sizes, np.ndarray): - sizes = size.tolist() + sizes = sizes.tolist() out = _op.broadcast_to(x, sizes) g.add_node(op.output("Out")[0], out) @@ -457,8 +456,7 @@ def convert_hard_sigmoid(g, op, block): slope = op.attr("slope") x = g.get_node(op.input("X")[0]) - dtype = infer_type(x).checked_type.dtype - out = x * _expr.const(slope, dtype) + _expr.const(0.5, dtype) + out = x * _expr.const(slope) + _expr.const(0.5) out = _op.clip(out, 0, 1) g.add_node(op.output("Out")[0], out) @@ -473,9 +471,8 @@ def convert_hard_swish(g, op, block): assert np.isclose(scale, 6.0), "Only support scale==6.0 for PaddlePaddle's hard_swish" assert np.isclose(threshold, 6.0), "Only support threshold==6.0 for PaddlePaddle's hard_swish" x = g.get_node(op.input("X")[0]) - dtype = infer_type(x).checked_type.dtype out = _op.clip(x, -1 * offset, offset) - out = out / _expr.const(threshold, dtype) + _expr.const(0.5, dtype) + out = out / _expr.const(threshold) + _expr.const(0.5) out = x * out g.add_node(op.output("Out")[0], out) @@ -520,74 +517,19 @@ def convert_leaky_relu(g, op, block): g.add_node(op.output("Out")[0], out) -def convert_log1p(g, op, block): - """Operator converter for log1p.""" - - x = g.get_node(op.input("X")[0]) - dtype = infer_type(x).checked_type.dtype - one = _expr.const(1, dtype=dtype) - out = _op.log(x + one) - g.add_node(op.output("Out")[0], out) - - def convert_lookup_table(g, op, block): """Operator converter for lookup_table_v2.""" indices = g.get_node(op.input("Ids")[0]) padding_idx = op.attr("padding_idx") - weights = g.get_node(op.input("W")[0]) if padding_idx != -1: - if op.input("W")[0] in g.get_params(): - # while `w` is a parameter - weights = g.get_params(op.input("W")[0]) - weights[padding_idx] = 0.0 - weights = _expr.const(weights) - else: - # while `w` is a tensor - shape = try_infer_value(shape_of(weights), g.get_params())[0] - assert not isinstance( - shape, _expr.Expr - ), "Shape of weight has to be fixed for PaddlePaddle's lookup_table" - filters = np.ones(shape.tolist()).astype(infer_type(weights).checked_type.dtype) - filters[padding_idx] = 0.0 - filters = _expr.const(filters) - weights = weights * filters + g.get_params[op.input("W")[0]][padding_idx] = 0.0 + g.add_node(op.input("W")[0], _expr.const(g.params[op.input("W")[0]])) + weights = g.get_node(op.input("W")[0]) out = _op.take(weights, indices.astype("int32"), axis=0) g.add_node(op.output("Out")[0], out) -def convert_logical_not(g, op, block): - """Operator converter for logical_not op.""" - - ipt0 = g.get_node(op.input("X")[0]) - op_func = get_relay_op(op.type) - out = op_func(ipt0) - g.add_node(op.output("Out")[0], out) - - -def convert_logsigmoid(g, op, block): - """Operator converter for logsigmoid.""" - - x = g.get_node(op.input("X")[0]) - out = _op.log(_op.tensor.sigmoid(x)) - g.add_node(op.output("Out")[0], out) - - -def convert_logsoftmax(g, op, block): - """Operator converter for logsoftmax.""" - - x = g.get_node(op.input("X")[0]) - axis = op.attr("axis") - ndim = len(infer_shape(x)) - if axis < 0: - axis += ndim 
- m = _op.max(x, [axis], keepdims=True) - e = _op.exp(x - m) - s = _op.sum(e, [axis], keepdims=True) - out = x - m - _op.log(s) - g.add_node(op.output("Out")[0], out) - - def convert_matmul(g, op, block): """Operator converter for matmul.""" @@ -697,16 +639,6 @@ def flatten_to_nd(x, x_shape, nd=3): g.add_node(op.output("Out")[0], out) -def convert_meshgrid(g, op, block): - """Operator converter for meshgrid.""" - - inputs = op.input("X") - x = [g.get_node(i) for i in inputs] - outs = _op.meshgrid(x, indexing="ij") - for i, out in enumerate(outs): - g.add_node(op.output("Out")[i], out) - - def convert_mul(g, op, block): """Operator converter for mul.""" @@ -714,8 +646,8 @@ def convert_mul(g, op, block): y = g.get_node(op.input("Y")[0]) x_num_col_dims = op.attr("x_num_col_dims") y_num_col_dims = op.attr("y_num_col_dims") - x_shape = _op.shape_of(x) - y_shape = _op.shape_of(y) + x_shape = shape_of(x) + y_shape = shape_of(y) x_dim = infer_shape(x_shape)[0] y_dim = infer_shape(y_shape)[0] if x_num_col_dims < 0: @@ -752,15 +684,6 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) -def convert_numel(g, op, block): - """Operator converter for numel.""" - - input_x = g.get_node(op.input("Input")[0]) - out = _op.ndarray_size(input_x, dtype="int64") - out = _op.expand_dims(out, axis=0) - g.add_node(op.output("Out")[0], out) - - def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -776,7 +699,7 @@ def convert_pool2d(g, op, block): ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - _, _, in_h, in_w = infer_shape(input_x) + in_h, in_w = infer_shape(input_x)[2:] op_map = { "avg": "avg_pool2d", @@ -793,19 +716,8 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - if strides[0] == 1 and strides[1] == 1: - pad_h = _get_pad_size(0, ksize[0], strides[0]) - pad_w = _get_pad_size(0, ksize[1], strides[1]) - else: - input_shape = shape_of(input_x) - h_w = _op.strided_slice(input_shape, [2], [4]) - try: - in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() - except Exception as e: - msg = "Dynamic shape is not supported in SAME padding algorithm while stride!=1" - raise tvm.error.OpAttributeInvalid(msg) from e - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) + pad_h = _get_pad_size(in_h, ksize[0], strides[0]) + pad_w = _get_pad_size(in_w, ksize[1], strides[1]) paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: @@ -816,11 +728,6 @@ def convert_pool2d(g, op, block): msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) - if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: - ksize[0] = in_h - if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: - ksize[1] = in_w - if not adaptive: out = getattr(_op.nn, op_map[pooling_type])( input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode @@ -837,26 +744,22 @@ def convert_reshape(g, op, block): input_shape_tensor = op.input("ShapeTensor") data = g.get_node(op.input("X")[0]) if input_shape: - # if the target shape is a 1D tensor new_shape = g.get_node(input_shape[0]) elif input_shape_tensor: - # if the target shape is a list of tensors - new_shape = [] + tmp_shape = [] for shape_name in input_shape_tensor: shape = g.get_node(shape_name) if len(infer_shape(shape)) == 0: - # sometimes the element maybe a scalar tensor 
shape = _op.reshape(shape, [-1]) - new_shape.append(shape.astype("int64")) - new_shape = _op.concatenate(new_shape, axis=0) - new_shape = try_infer_value(new_shape, g.get_params())[0] + if isinstance(shape, _expr.Constant): + tmp_shape.append(shape) + elif isinstance(shape, _expr.Expr): + tmp_shape.append(shape) + else: + tmp_shape.append(_expr.const(np.array(shape).astype("int64"))) + new_shape = _op.concatenate(tmp_shape, axis=0) else: - # if the target shape is a list of constant value new_shape = op.attr("shape") - - if isinstance(new_shape, np.ndarray): - new_shape = new_shape.tolist() - out = _op.reshape(data, new_shape) g.add_node(op.output("Out")[0], out) @@ -869,11 +772,8 @@ def convert_scale(g, op, block): bias_after_scale = op.attr("bias_after_scale") x = g.get_node(op.input("X")[0]) if np.isclose(scale, 1.0) and np.isclose(bias, 0.0): - out = x + out = _op.copy(x) else: - x_dtype = infer_type(x).checked_type.dtype - if x_dtype != "float32": - x = x.astype("float32") if np.isclose(bias, 0.0): out = x * _expr.const(np.array(scale).astype("float32")) elif np.isclose(scale, 1.0): @@ -887,8 +787,6 @@ def convert_scale(g, op, block): out = (x + _expr.const(np.array(bias).astype("float32"))) * _expr.const( np.array(scale).astype("float32") ) - if x_dtype != "float32": - out = out.astype(x_dtype) g.add_node(op.output("Out")[0], out) @@ -903,128 +801,39 @@ def convert_shape(g, op, block): def convert_slice(g, op, block): """Operator converter for slice.""" - data = g.get_node(op.input("Input")[0]) - dims = len(infer_shape(data)) + def parameter_process(starts, ends, axes, dshape): + new_axes = [] + new_starts = [] + new_ends = [] + pop_index = 0 + for i in range(max(axes) + 1): + new_axes.append(i) + if i in axes: + new_starts.append(starts[pop_index]) + new_ends.append(ends[pop_index]) + pop_index += 1 + else: + new_starts.append(0) + new_ends.append(dshape[i]) + return new_starts, new_ends, new_axes + data = g.get_node(op.input("Input")[0]) + dshape = infer_shape(data) + starts = op.attr("starts") + ends = op.attr("ends") axes = op.attr("axes") - indices = _expr.const(axes, dtype="int64") - decrease_axis = op.attr("decrease_axis") + if isinstance(starts, int): + starts = [starts] + if isinstance(ends, int): + ends = [ends] + if isinstance(axes, int): + axes = [axes] if isinstance(decrease_axis, int): decrease_axis = [decrease_axis] - - if op.input("StartsTensor"): - # if `starts` is a 1D tensor - starts = g.get_node(op.input("StartsTensor")[0]) - starts = try_infer_value(starts, g.get_params())[0] - elif op.input("StartsTensorList"): - # if `starts` is a list of tensors - starts = [] - for start_index in op.input("StartsTensorList"): - start_index = g.get_node(start_index).astype("int64") - starts.append(start_index) - starts = _op.concatenate(starts, axis=0) - starts = try_infer_value(starts, g.get_params())[0] - else: - # if `starts` is a list of constant values - starts = op.attr("starts") - - if isinstance(starts, np.ndarray): - starts = starts.tolist() - - if len(axes) < dims: - if isinstance(starts, _expr.Expr): - starts = _op.scatter( - _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype), - indices, - starts, - axis=0, - ) - else: - base = [0] * dims - for i, axis in enumerate(axes): - base[axis] = starts[i] - starts = base - - if op.input("EndsTensor"): - # if `ends` is a 1D tensor - ends = g.get_node(op.input("EndsTensor")[0]) - ends = try_infer_value(ends, g.get_params())[0] - elif op.input("EndsTensorList"): - # if `ends` is a list of tensors - ends = [] - 
for end_index in op.input("EndsTensorList"): - end_index = g.get_node(end_index).astype("int64") - ends.append(end_index) - ends = _op.concatenate(ends, axis=0) - ends = try_infer_value(ends, g.get_params())[0] - else: - # if `ends` is a list of constant values - ends = op.attr("ends") - - if isinstance(ends, np.ndarray): - ends = ends.tolist() - - if len(axes) < dims: - if isinstance(ends, _expr.Expr): - ends = _op.scatter( - _expr.const( - np.array([np.iinfo(np.int32).max] * dims), - dtype=infer_type(ends).checked_type.dtype, - ), - indices, - ends, - axis=0, - ) - else: - base = [np.iinfo(np.int32).max] * dims - for i, axis in enumerate(axes): - base[axis] = ends[i] - ends = base - - strides = None - if "StridesTensor" in op.input_names and op.input("StridesTensor"): - # if `strides` is a 1D tensor - strides = g.get_node(op.input("StridesTensor")[0]) - strides = try_infer_value(strides, g.get_params())[0] - elif "StridesTensorList" in op.input_names and op.input("StridesTensorList"): - # if `strides` is a list of tensors - strides = [] - for strides_index in op.input("StridesTensorList"): - strides_index = g.get_node(strides_index).astype("int64") - strides.append(strides_index) - strides = _op.concatenate(strides, axis=0) - strides = try_infer_value(strides, g.get_params())[0] - elif op.has_attr("strides"): - # if `strides` is a list of constant values - strides = op.attr("strides") - - if isinstance(strides, np.ndarray): - strides = strides.tolist() - - if len(axes) < dims: - if isinstance(strides, _expr.Expr): - strides = _op.scatter( - _expr.const( - np.array([1] * dims), - dtype=infer_type(strides).checked_type.dtype, - ), - indices, - strides, - axis=0, - ) - elif strides: - base = [1] * dims - for i, axis in enumerate(axes): - base[axis] = strides[i] - strides = base - if not strides: - strides = _op.const([1] * dims, dtype="int64") - - out = _op.strided_slice(data, begin=starts, end=ends, strides=strides) + starts, ends, axes = parameter_process(starts, ends, axes, dshape) + out = _op.strided_slice(data, begin=starts, end=ends) if decrease_axis: - # `decrease_axis` is False while using paddle.slice() - # `decrease_axis` is True while using tensor[1:2] out = _op.squeeze(out, axis=decrease_axis) g.add_node(op.output("Out")[0], out) @@ -1054,22 +863,15 @@ def convert_unsqueeze(g, op, block): _convert_map = { - "abs": convert_unary_op, - "acos": convert_unary_op, "arg_max": convert_arg_max_min, "arg_min": convert_arg_max_min, "argsort": convert_argsort, - "asin": convert_unary_op, "assign": convert_assign, "assign_value": convert_assign_value, - "atan": convert_unary_op, "batch_norm": convert_batch_norm, "cast": convert_cast, - "ceil": convert_unary_op, "concat": convert_concat, "conv2d": convert_conv2d, - "cos": convert_unary_op, - "cosh": convert_unary_op, "cumsum": convert_cumsum, "depthwise_conv2d": convert_conv2d, "dot": convert_dot, @@ -1078,23 +880,14 @@ def convert_unsqueeze(g, op, block): "elementwise_div": convert_elementwise_op, "elementwise_mul": convert_elementwise_op, "elementwise_sub": convert_elementwise_op, - "elementwise_mod": convert_elementwise_op, - "elementwise_max": convert_elementwise_op, - "elementwise_min": convert_elementwise_op, - "elementwise_pow": convert_elementwise_op, - "elementwise_floordiv": convert_elementwise_op, "equal": convert_elementwise_op, - "erf": convert_unary_op, "exp": convert_unary_op, "expand_v2": convert_expand, "expand_as_v2": convert_expand_as, "feed": convert_feed, "fill_any_like": convert_fill_any_like, "fill_constant": 
convert_fill_constant, - "floor": convert_unary_op, "gelu": convert_gelu, - "greater_equal": convert_elementwise_op, - "greater_than": convert_elementwise_op, "hard_sigmoid": convert_hard_sigmoid, "hard_swish": convert_hard_swish, "isfinite_v2": convert_unary_op, @@ -1102,41 +895,20 @@ def convert_unsqueeze(g, op, block): "isnan_v2": convert_unary_op, "layer_norm": convert_layer_norm, "leaky_relu": convert_leaky_relu, - "less_equal": convert_elementwise_op, - "less_than": convert_elementwise_op, - "log": convert_unary_op, - "log2": convert_unary_op, - "log10": convert_unary_op, - "log1p": convert_log1p, "logical_and": convert_binary_logical_op, - "logical_not": convert_logical_not, "logical_or": convert_binary_logical_op, "logical_xor": convert_binary_logical_op, - "logsigmoid": convert_logsigmoid, - "log_softmax": convert_logsoftmax, "lookup_table_v2": convert_lookup_table, "matmul": convert_matmul, "matmul_v2": convert_matmul, - "meshgrid": convert_meshgrid, "mul": convert_mul, - "not_equal": convert_elementwise_op, "pool2d": convert_pool2d, "relu": convert_unary_op, "reshape2": convert_reshape, - "round": convert_unary_op, - "rsqrt": convert_unary_op, "scale": convert_scale, "shape": convert_shape, - "sigmoid": convert_unary_op, - "sign": convert_unary_op, - "sin": convert_unary_op, - "sinh": convert_unary_op, - "size": convert_numel, "slice": convert_slice, "softmax": convert_softmax, - "sqrt": convert_unary_op, - "strided_slice": convert_slice, - "tan": convert_unary_op, "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } @@ -1145,38 +917,30 @@ def convert_unsqueeze(g, op, block): class GraphProto: """A helper class for handling relay functions from PaddlePaddle model.""" - def __init__(self, freeze_params=False): + def __init__(self): self.nodes = {} self.params = {} self.shape_dict = None - self.freeze_params = freeze_params def get_node(self, name): """get node from graph""" - assert name in self.nodes, "Node: {} not found".format(name) + assert name in self.nodes return self.nodes[name] def add_node(self, name, node): """add a node to graph""" - if self.shape_dict: - self.nodes[name] = fold_constant(node) - else: - self.nodes[name] = node + + self.nodes[name] = fold_constant(node) def get_params(self, name=None): - """get params from graph""" + """Get params from graph.""" if name is None: return self.params assert name in self.params return self.params[name] - def set_params(self, params): - """set params for graph""" - - self.params = params - def extract_parameters(self, program, scope=None): """Extract all the weights from PaddlePaddle program.""" @@ -1191,13 +955,21 @@ def extract_parameters(self, program, scope=None): if isinstance(scope, dict): self.params[name] = scope[name] else: - self.params[name] = _nd.array(np.array(scope.var(name).get_tensor())) - if self.freeze_params: - self.nodes[name] = _expr.const(self.params[name]) - else: - self.nodes[name] = _expr.var( - name, shape=self.params[name].shape, dtype=str(self.params[name].dtype) + self.params[name] = np.array(scope.var(name).get_tensor()) + self.nodes[name] = _expr.const(self.params[name]) + + def check_input_shape(self, op, block): + """Check the shape information of model's inputs, fixed shape is recommended.""" + + ipt_name = op.input(op.input_names[0]) + ipt_shape = block.var(ipt_name).shape + for i in ipt_shape: + if i < 0: + warning_msg = "Input {}(shape={}) has unkown dimension shapes. 
\ + Specifying static values may improve performance".format( + ipt_name, ipt_shape ) + warnings.warn(warning_msg) def check_unsupported_ops(self, program): """Check whether all the operators are supported.""" @@ -1206,8 +978,6 @@ def check_unsupported_ops(self, program): for block in program.blocks: for op in block.ops: if op.type == "fetch": - # `fetch` is a flag of output tensors - # there's no need to handle this continue if op.type not in _convert_map: unsupported_ops.add(op.type) @@ -1222,12 +992,12 @@ def ops_to_relay(self, program, input_specs=None): if input_specs is not None: for input_spec in input_specs: convert_feed(self, input_spec, None) - global_block = program.blocks[0] - for op in global_block.ops: - if op.type == "fetch": - continue - convert_func = _convert_map[op.type] - convert_func(self, op, global_block) + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + continue + convert_func = _convert_map[op.type] + convert_func(self, op, block) def from_program(self, program, shape_dict, scope): """Construct the TVM relay expression from PaddlePaddle program.""" @@ -1247,14 +1017,12 @@ def from_program(self, program, shape_dict, scope): if op.type == "fetch": output_names.append(op.input("X")[0]) - outputs = [self.get_node(name) for name in output_names] + outputs = [self.nodes[name] for name in output_names] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) - if self.freeze_params: - self.params = {} return mod, self.params def from_translated_layer(self, layer, shape_dict): @@ -1273,53 +1041,25 @@ def from_translated_layer(self, layer, shape_dict): output_names = [x.name for x in layer._output_spec()] - outputs = [self.get_node(name) for name in output_names] + outputs = [self.nodes[name] for name in output_names] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) - if self.freeze_params: - self.params = {} return mod, self.params -def from_paddle(program_or_layer, shape_dict=None, scope=None, freeze_params=False): +def from_paddle(program_or_layer, shape_dict=None, scope=None): """Convert a PaddlePaddle model into an equivalent Relay Function. - PaddlePaddle Program/TranslatedLayer represent the computation - graph of PaddlePaddle model, and PaddlePaddle scope stores all the - weights of PaddlePaddle model. - - Parameters - ---------- - program_or_layer : Program/TranslatedLayer object - Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load` - - shape_dict : dict of str to tuple, optional - The input shape to the model - - scope : Scope object, optional - All the weights saved in scope, by default, use `paddle.fluid.global_scope` - - freeze_params : bool - If this parameter is true, the importer will take any provided weights and - embed them into the relay model as Constants instead of variables. This - allows more aggressive optimizations at compile time and helps in making - models static if certain inputs represent attributes relay would traditionally - consider compile-time constants. 
- - Returns - ------- - mod : tvm.IRModule - The relay module for compilation - - params : dict of str to tvm.nd.NDArray - The parameter dict to be used by relay + + PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, + and PaddlePaddle scope stores all the weights of PaddlePaddle model. """ import paddle - g = GraphProto(freeze_params) + g = GraphProto() if isinstance(program_or_layer, paddle.jit.TranslatedLayer): # model is loaded by `paddle.jit.load` mod, params = g.from_translated_layer(program_or_layer, shape_dict) diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 418e22fd4409..ac6cd0ed94d9 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -14,21 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os from pathlib import Path import shutil import numpy as np - -import paddle -from paddle.framework import dtype -import paddle.nn as nn - import tvm import tvm.testing import tvm.topi.testing from tvm import relay from tvm.contrib import graph_executor +import paddle +import paddle.nn as nn PADDLE_TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data", "paddle") PADDLE_TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) @@ -36,16 +34,14 @@ def assert_shapes_match(tru, est): if tru.shape != est.shape: - msg = "Paddle Output shapes {} and TVM shapes {} don't match" + msg = "Output shapes {} and {} don't match" raise AssertionError(msg.format(tru.shape, est.shape)) - if tru.dtype != est.dtype: - msg = "Paddle Output dtype {} and TVM dtype {} don't match" - raise AssertionError(msg.format(tru.dtype, est.dtype)) def get_paddle_model(func, input_spec): global PADDLE_TEST_DATA_ROOT_PATH model_path = Path(PADDLE_TEST_DATA_ROOT_PATH, "model") + paddle.jit.save(func, str(model_path), input_spec=input_spec) baseline_model = paddle.jit.load(str(model_path)) @@ -53,34 +49,7 @@ def get_paddle_model(func, input_spec): return baseline_model -def get_tvm_output_with_vm(mod, params, target, device, input_data): - """Generic function to execute and get tvm output with vm executor""" - - ex = relay.create_executor("vm", mod=mod, device=device, target=target) - params.update(input_data) - result = ex.evaluate()(**params) - if isinstance(result, tvm.runtime.NDArray): - return [ - result.numpy(), - ] - return [r.numpy() for r in result] - - -def get_tvm_output(mod, params, target, device, input_data, compiled_names, num): - """Generic function to execute and get tvm output""" - - lib = relay.build(mod, target=target, params=params) - gmod = graph_executor.GraphModule(lib["default"](device)) - for name in compiled_names: - gmod.set_input(name, input_data[name]) - gmod.run() - outputs = [] - for i in range(num): - outputs.append(gmod.get_output(i).numpy()) - return outputs - - -def verify_model(func, input_data, rtol=1e-5, atol=1e-5, input_shape=None): +def verify_model(func, input_data, rtol=1e-5, atol=1e-5): if not (isinstance(input_data, (tuple, list))): input_data = [input_data] @@ -90,13 +59,11 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5, input_shape=None): compiled_input = {} for idx, data in enumerate(input_data): input_name = "input{}".format(idx) - if input_shape: - shape = input_shape[idx] - else: - shape = data.shape - input_shape_dict[input_name] = shape - 
input_spec.append(paddle.static.InputSpec(dtype=data.dtype, shape=shape, name=input_name)) + input_spec.append( + paddle.static.InputSpec(dtype=data.dtype, shape=data.shape, name=input_name) + ) input_names.append(input_name) + input_shape_dict[input_name] = data.shape if isinstance(data, np.ndarray): compiled_input[input_name] = data else: @@ -114,70 +81,24 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5, input_shape=None): mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) parms_num = min(len(input_names), len(mod["main"].params)) compiled_names = [] - for arg in mod["main"].params: + for arg in mod["main"].params[:parms_num]: assert arg.name_hint in input_names or arg.name_hint in params if arg.name_hint in input_names: compiled_names.append(arg.name_hint) with tvm.transform.PassContext(opt_level=3): for target, dev in tvm.testing.enabled_targets(): - if input_shape: - tvm_output = get_tvm_output_with_vm(mod, params, target, dev, compiled_input) - else: - tvm_output = get_tvm_output( - mod, params, target, dev, compiled_input, compiled_names, len(baseline_outputs) - ) - - for baseline_output, compiled_output in zip(baseline_outputs, tvm_output): - assert_shapes_match(baseline_output, compiled_output) - tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) + lib = relay.build(mod, target=target, params=params) + gmod = graph_executor.GraphModule(lib["default"](dev)) + for name in compiled_names: + gmod.set_input(name, compiled_input[name]) + gmod.run() + for i, baseline_output in enumerate(baseline_outputs): + compiled_output = gmod.get_output(i).numpy() -@tvm.testing.uses_gpu -def test_forward_math(): - class MathAPI(nn.Layer): - def __init__(self, api_name): - super(MathAPI, self).__init__() - for candidate in (paddle, paddle.nn.functional): - self.func = getattr(candidate, api_name, None) - if self.func: - break - - @paddle.jit.to_static - def forward(self, inputs): - return self.func(inputs) - - input_data = paddle.rand([1, 2, 5, 5], dtype="float32") - api_list = [ - "abs", - "acos", - "asin", - "atan", - "ceil", - "cos", - "cosh", - "erf", - "exp", - "floor", - "log", - "log2", - "log10", - "log1p", - "numel", - "relu", - "round", - "rsqrt", - "sigmoid", - "sign", - "rsqrt", - "sin", - "sinh", - "sqrt", - "tan", - "tanh", - ] - for api_name in api_list: - verify_model(MathAPI(api_name), input_data) + assert_shapes_match(baseline_output, compiled_output) + tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) @tvm.testing.uses_gpu @@ -205,7 +126,7 @@ def add_subtract3(inputs1, inputs2): @tvm.testing.uses_gpu -def test_forward_argmax(): +def test_forward_arg_max_min(): input_shape = [1, 3, 10, 10] class ArgMax(nn.Layer): @@ -228,17 +149,6 @@ class ArgMax3(nn.Layer): def forward(self, inputs): return inputs.argmax(axis=2, keepdim=True) - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(ArgMax(), input_data=input_data) - verify_model(ArgMax1(), input_data=input_data) - verify_model(ArgMax2(), input_data=input_data) - verify_model(ArgMax3(), input_data=input_data) - - -@tvm.testing.uses_gpu -def test_forward_argmin(): - input_shape = [1, 3, 10, 10] - class ArgMin(nn.Layer): @paddle.jit.to_static def forward(self, inputs): @@ -252,7 +162,7 @@ def forward(self, inputs): class ArgMin2(nn.Layer): @paddle.jit.to_static def forward(self, inputs): - return inputs.argmin(axis=1, keepdim=False) + return inputs.argmax(axis=1, keepdim=False) class ArgMin3(nn.Layer): 
@paddle.jit.to_static @@ -260,6 +170,10 @@ def forward(self, inputs): return inputs.argmin(axis=2, keepdim=True) input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax(), input_data=input_data) + verify_model(ArgMax1(), input_data=input_data) + verify_model(ArgMax2(), input_data=input_data) + verify_model(ArgMax3(), input_data=input_data) verify_model(ArgMin(), input_data=input_data) verify_model(ArgMin1(), input_data=input_data) verify_model(ArgMin2(), input_data=input_data) @@ -285,16 +199,31 @@ def argsort2(inputs): @tvm.testing.uses_gpu def test_forward_assign(): - class Assign(nn.Layer): - @paddle.jit.to_static - def forward(self, inputs): - return paddle.assign(inputs) + @paddle.jit.to_static + def assign(inputs): + return paddle.assign(inputs) + + @paddle.jit.to_static + def assign_value(inputs): + x = paddle.to_tensor(np.array([3]).astype("float32")) + return inputs + x input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(Assign(), [input_data]) + verify_model( + assign, + [ + input_data, + ], + ) input_data2 = np.random.randint(100, size=input_shape) - verify_model(Assign(), [input_data2], input_shape=[[-1, -1]]) + verify_model( + assign, + [ + input_data2, + ], + ) + verify_model(assign_value, [input_data]) @tvm.testing.uses_gpu @@ -346,8 +275,39 @@ def cast2(inputs, dtype="int64"): input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") * 100 - verify_model(cast1, [input_data]) - verify_model(cast2, [input_data]) + verify_model( + cast1, + [ + input_data, + ], + ) + verify_model( + cast2, + [ + input_data, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_check_tensor(): + @paddle.jit.to_static + def isfinite(inputs): + return paddle.cast(paddle.isfinite(inputs), "int32") + + @paddle.jit.to_static + def isnan(inputs): + return paddle.cast(paddle.isnan(inputs), "int32") + + @paddle.jit.to_static + def isinf(inputs): + return paddle.cast(paddle.isinf(inputs), "int32") + + input_shape = [5, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(isfinite, input_data=input_data) + verify_model(isnan, input_data=input_data) + verify_model(isinf, input_data=input_data) @tvm.testing.uses_gpu @@ -371,30 +331,39 @@ def concat_unsqueeze2(inputs): @tvm.testing.uses_gpu def test_forward_cumsum(): - class Cumsum1(nn.Layer): - @paddle.jit.to_static - def forward(self, inputs): - return paddle.cumsum(inputs) + @paddle.jit.to_static + def cusum1(inputs): + return paddle.cumsum(inputs) - class Cumsum2(nn.Layer): - @paddle.jit.to_static - def forward(self, inputs): - return paddle.cumsum(inputs, axis=0) + @paddle.jit.to_static + def cusum2(inputs): + return paddle.cumsum(inputs, axis=0) - class Cumsum3(nn.Layer): - @paddle.jit.to_static - def forward(self, inputs): - return paddle.cumsum(inputs, axis=1) + @paddle.jit.to_static + def cusum3(inputs): + return paddle.cumsum(inputs, axis=1) input_data = paddle.randint(0, 100, (10, 10), dtype=paddle.int32) - verify_model(Cumsum1(), input_data) - verify_model(Cumsum1(), [input_data.astype(paddle.int64)]) - verify_model(Cumsum2(), input_data) - verify_model(Cumsum3(), input_data) + verify_model(cusum1, [input_data]) + verify_model(cusum1, [input_data.astype(paddle.int64)]) + verify_model( + cusum2, + [ + input_data, + ], + ) + verify_model( + cusum3, + [ + input_data, + ], + ) @tvm.testing.uses_gpu def test_forward_conv(): + conv2d_input_shape = [1, 3, 10, 10] + class Conv2D1(nn.Layer): def __init__(self): super(Conv2D1, self).__init__() @@ -415,51 
+384,22 @@ def __init__(self): def forward(self, inputs): return self.softmax(self.conv(inputs)) - class Conv2D3(nn.Layer): - def __init__(self): - super(Conv2D3, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False, padding="SAME") - - @paddle.jit.to_static - def forward(self, inputs): - return self.conv(inputs) - - class Conv2D4(nn.Layer): - def __init__(self): - super(Conv2D4, self).__init__() - self.conv = nn.Conv2D( - 3, 6, 7, groups=3, bias_attr=False, padding=[1, 2, 0, 1], stride=2, dilation=2 - ) - - @paddle.jit.to_static - def forward(self, inputs): - return self.conv(inputs) - - conv2d_input_shape = [1, 3, 112, 112] conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") verify_model(Conv2D1(), input_data=conv2d_input_data) verify_model(Conv2D2(), input_data=conv2d_input_data) - verify_model(Conv2D3(), input_data=conv2d_input_data) - verify_model(Conv2D4(), input_data=conv2d_input_data) - verify_model(Conv2D1(), conv2d_input_data, input_shape=[[-1, 3, 112, 112]]) @tvm.testing.uses_gpu def test_forward_dot(): @paddle.jit.to_static - def dot1(x, y): - return paddle.dot(x, y) - - @paddle.jit.to_static - def dot2(x): - y = paddle.to_tensor(np.random.rand(10).astype("float32")) + def dot(x, y): return paddle.dot(x, y) - x_data = paddle.rand([10, 3], dtype="float32") - y_data = paddle.rand([10, 3], dtype="float32") - verify_model(dot1, input_data=[x_data, y_data]) - x_data = paddle.rand([10], dtype="float32") - verify_model(dot2, input_data=[x_data]) + x_shape = [10, 3] + y_shape = [10, 3] + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(dot, input_data=[x_data, y_data]) @tvm.testing.uses_gpu @@ -474,6 +414,37 @@ def dropout(inputs): verify_model(dropout, input_data=input_data) +def test_forward_elemwise(): + class ElemwiseAPI(nn.Layer): + def __init__(self, api_name): + super(ElemwiseAPI, self).__init__() + self.api_name_ = api_name + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, api_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, input1, input2): + y = self.func(input1, input2) + if "equal" in self.api_name_ or "than" in self.api_name_: + # for compare operation, cast boolean result to int32 + y = paddle.cast(y, "int32") + return y + + api_list = [ + "equal", + ] + input_shape = [10, 10] + input_shape_2 = [ + 10, + ] + x_data = paddle.randint(1, 10, input_shape, dtype="int32") + y_data = paddle.randint(1, 10, input_shape_2, dtype="int32") + for api_name in api_list: + verify_model(ElemwiseAPI(api_name), [x_data, y_data]) + + @tvm.testing.uses_gpu def test_forward_expand(): @paddle.jit.to_static @@ -481,14 +452,14 @@ def expand1(inputs): return paddle.expand(inputs, shape=[2, 3]) @paddle.jit.to_static - def expand2(inputs, shape): + def expand2(inputs): + shape = paddle.to_tensor(np.array([2, 3]).astype("int32")) return paddle.expand(inputs, shape=shape) x_shape = [3] x_data = paddle.rand(x_shape, dtype="float32") verify_model(expand1, input_data=[x_data]) - shape = paddle.to_tensor(np.array([2, 3]).astype("int32")) - verify_model(expand2, [x_data, shape], input_shape=[[3], [2]]) + verify_model(expand2, input_data=[x_data]) @tvm.testing.uses_gpu @@ -508,21 +479,15 @@ def expand_as(x, y): def test_forward_shape_full(): @paddle.jit.to_static def full1(inputs): - return paddle.full(inputs, 3.14) + return paddle.full(paddle.shape(inputs), 3.14) @paddle.jit.to_static def full2(inputs): return 
paddle.full(paddle.shape(inputs), 1.0, dtype=inputs.dtype) - @paddle.jit.to_static - def shape1(inputs): - return paddle.shape(inputs) - input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(shape1, input_data=[input_data]) - shape = paddle.to_tensor(np.array(input_shape, "int32")) - verify_model(full1, input_data=[shape], input_shape=[[4]]) + verify_model(full1, input_data=[input_data]) verify_model(full2, input_data=[input_data]) @@ -539,71 +504,7 @@ def ones_like2(inputs): input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") verify_model(ones_like1, input_data=input_data) - verify_model(ones_like2, input_data, input_shape=[[-1, -1, -1, -1]]) - - -@tvm.testing.uses_gpu -def test_forward_ones(): - @paddle.jit.to_static - def ones1(inputs): - ones = paddle.ones([1, 3, 10, 10]) - out = inputs + ones - return out - - @paddle.jit.to_static - def ones2(inputs): - shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32") - ones = paddle.ones(shape) - out = inputs + ones - return out - - input_shape = [1, 3, 10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(ones1, input_data=input_data) - verify_model(ones2, input_data=input_data) - - -def test_forward_elemwise(): - class ElemwiseAPI(nn.Layer): - def __init__(self, api_name): - super(ElemwiseAPI, self).__init__() - self.api_name_ = api_name - for candidate in (paddle, paddle.nn.functional): - self.func = getattr(candidate, api_name, None) - if self.func: - break - - @paddle.jit.to_static - def forward(self, input1, input2): - y = self.func(input1, input2) - if "equal" in self.api_name_ or "than" in self.api_name_: - y = paddle.cast(y, "int32") - return y - - api_list = [ - "floor_divide", - "floor_mod", - "maximum", - "minimum", - "equal", - "greater_equal", - "greater_than", - "less_equal", - "less_than", - "not_equal", - ] - input_shape = [10, 10] - input_shape_2 = [ - 10, - ] - x_data = paddle.rand(input_shape, dtype="float32") - y_data = paddle.rand(input_shape_2, dtype="float32") - x_data_2 = paddle.randint(1, 100, input_shape_2, dtype="int32") - y_data_2 = paddle.randint(1, 100, input_shape, dtype="int32") - for api_name in api_list: - if api_name not in ["floor_divide"]: - verify_model(ElemwiseAPI(api_name), [x_data, y_data]) - verify_model(ElemwiseAPI(api_name), [x_data_2, y_data_2]) + verify_model(ones_like2, input_data=input_data) @tvm.testing.uses_gpu @@ -618,54 +519,25 @@ def gelu(inputs): @tvm.testing.uses_gpu -def test_forward_activation(): - class Activation(nn.Layer): - def __init__(self, op_name): - super(Activation, self).__init__() - self.op_name_ = op_name - for candidate in (paddle.nn.functional, paddle): - self.func = getattr(candidate, op_name, None) - if self.func: - break - - @paddle.jit.to_static - def forward(self, inputs): - return self.func(inputs) +def test_forward_hard_sigmoid(): + @paddle.jit.to_static + def hard_sigmoid(inputs): + return nn.functional.hardsigmoid(inputs) input_shape = [1, 3, 10, 10] - input_data = paddle.normal(shape=input_shape) * 10.0 - input_data_2 = paddle.normal(shape=input_shape).astype("float64") * 10.0 - op_list = [ - "hardsigmoid", - "hardswish", - "leaky_relu", - "log_sigmoid", - "log_softmax", - "sigmoid", - ] - for op_name in op_list: - verify_model(Activation(op_name), input_data=input_data) + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_sigmoid, input_data=input_data) @tvm.testing.uses_gpu -def test_forward_check_tensor(): - @paddle.jit.to_static - 
def isfinite(inputs): - return paddle.cast(paddle.isfinite(inputs), "int32") - +def test_forward_hard_swish(): @paddle.jit.to_static - def isinf(inputs): - return paddle.cast(paddle.isinf(inputs), "int32") - - @paddle.jit.to_static - def isnan(inputs): - return paddle.cast(paddle.isnan(inputs), "int32") + def hard_swish(inputs): + return nn.functional.hardswish(inputs) - input_shape = [5, 5] + input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(isfinite, input_data=input_data) - verify_model(isinf, input_data=input_data) - verify_model(isnan, input_data=input_data) + verify_model(hard_swish, input_data=input_data) @tvm.testing.uses_gpu @@ -693,47 +565,37 @@ def forward(self, inputs): @tvm.testing.uses_gpu -def test_forward_logical_op(): - class LogicalOp(nn.Layer): - def __init__(self, op_name, out=False): - super(LogicalOp, self).__init__() - self.out = out +def test_forward_leaky_relu(): + @paddle.jit.to_static + def leaky_relu(inputs): + return nn.functional.leaky_relu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(leaky_relu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_logical_api(): + class LogicalAPI(nn.Layer): + def __init__(self, api_name): + super(LogicalAPI, self).__init__() for candidate in (paddle, paddle.nn.functional): - self.func = getattr(candidate, op_name, None) + self.func = getattr(candidate, api_name, None) if self.func: break @paddle.jit.to_static def forward(self, x, y): - if self.out: - out = paddle.to_tensor([True, True, True]) - z = self.func(x, y, out=out) - else: - z = self.func(x, y) + out = paddle.to_tensor([True, True, True]) + z = self.func(x, y, out=out) return paddle.cast(z, "int32") - class LogicalOp_not(LogicalOp): - @paddle.jit.to_static - def forward(self, x): - if self.out: - out = paddle.to_tensor([True, True, True]) - z = self.func(x, out=out) - else: - z = self.func(x) - return paddle.cast(z, "int32") - - op_list = [ - "logical_or", - "logical_xor", - "logical_and", - ] x = paddle.to_tensor([True]) y = paddle.to_tensor([True, False, True, False]) - for op_name in op_list: - verify_model(LogicalOp(op_name, False), [x, y]) - verify_model(LogicalOp(op_name, True), [x, y]) - verify_model(LogicalOp_not("logical_not", False), [y]) - verify_model(LogicalOp_not("logical_not", True), [y]) + verify_model(LogicalAPI("logical_and"), [x, y]) + verify_model(LogicalAPI("logical_or"), [x, y]) + verify_model(LogicalAPI("logical_xor"), [x, y]) @tvm.testing.uses_gpu @@ -755,7 +617,7 @@ def forward(self, inputs): input_data = paddle.randint(0, 10, input_shape, dtype="int32") weight = paddle.rand([10, 4], dtype="float32") verify_model(look_up, input_data=[input_data, weight]) - verify_model(LookUp(), input_data, input_shape=[[-1, -1, -1, -1]]) + verify_model(LookUp(), input_data=input_data) @tvm.testing.uses_gpu @@ -808,44 +670,6 @@ def forward(self, input1, input2): verify_model(MatMul1(), input_data=[input_data1, input_data2]) -@tvm.testing.uses_gpu -def test_forward_meshgrid(): - @paddle.jit.to_static - def t(x, y, z): - return paddle.meshgrid(x, y, z) - - x = paddle.randint(low=0, high=100, shape=[2]) - y = paddle.randint(low=0, high=100, shape=[3]) - z = paddle.randint(low=0, high=100, shape=[5]) - verify_model(t, [x, y, z]) - - -def test_forward_mm(): - class Mm(nn.Layer): - def forward(self, input1, input2): - return paddle.mm(input1, input2) - - # matrix x vector - input_data1 = paddle.randn((3, 4), dtype="float32") - 
input_data2 = paddle.randn((4,), dtype="float32") - verify_model(Mm(), input_data=[input_data1, input_data2]) - - # matrix x matrix - input_data1 = paddle.randn((5, 4), dtype="float32") - input_data2 = paddle.randn((4, 5), dtype="float32") - verify_model(Mm(), input_data=[input_data1, input_data2]) - - # batched matrix x batched matrix - input_data1 = paddle.randn((10, 3, 4), dtype="float32") - input_data2 = paddle.randn((10, 4, 5), dtype="float32") - verify_model(Mm(), input_data=[input_data1, input_data2]) - - # batched matrix x broadcasted matrix - input_data1 = paddle.randn((10, 3, 4), dtype="float32") - input_data2 = paddle.randn((4, 5), dtype="float32") - verify_model(Mm(), input_data=[input_data1, input_data2]) - - @tvm.testing.uses_gpu def test_forward_pool2d(): @paddle.jit.to_static @@ -858,42 +682,21 @@ def pool2d2(inputs): @paddle.jit.to_static def pool2d3(inputs): - output = nn.functional.max_pool2d(inputs, kernel_size=2, stride=2, padding=0) - return output - - @paddle.jit.to_static - def pool2d4(inputs): - output, max_indices = nn.functional.max_pool2d( + return nn.functional.max_pool2d( inputs, kernel_size=2, stride=2, padding=0, return_mask=True ) - return output input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) - verify_model(pool2d1, input_data, input_shape=[[-1, 2, 32, 32]]) + verify_model(pool2d1, input_data=input_data) verify_model(pool2d2, input_data=input_data) - input_data1 = paddle.uniform(shape=[1, 2, 1, 50], dtype="float32", min=-1, max=1) - verify_model(pool2d3, input_data=input_data1) - - -@tvm.testing.uses_gpu -def test_forward_rank(): - class Rank(nn.Layer): - @paddle.jit.to_static - def forward(self, inputs): - rank = paddle.rank(inputs) - rank = paddle.unsqueeze(rank, axis=0) - output = inputs + rank - return output - - input_shape = [1, 2, 1, 3, 1] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(Rank(), input_data=input_data) + # verify_model(pool2d3, input_data=input_data) @tvm.testing.uses_gpu def test_forward_reshape(): @paddle.jit.to_static - def reshape1(inputs, new_shape): + def reshape1(inputs, x): + new_shape = paddle.shape(x) return paddle.reshape(inputs, new_shape) @paddle.jit.to_static @@ -903,7 +706,7 @@ def reshape2(inputs): @paddle.jit.to_static def reshape3(inputs): data_shape = inputs.shape - return inputs.reshape([data_shape[1], data_shape[2], data_shape[0]]) + return inputs.reshape([data_shape[0] * data_shape[1], data_shape[2]]) @paddle.jit.to_static def reshape4(inputs, x): @@ -913,8 +716,7 @@ def reshape4(inputs, x): input_shape = [2, 1, 10, 1, 10] input_data = paddle.rand(input_shape, dtype="float32") input_data2 = paddle.randn([2, 1, 10, 10]) - new_shape = paddle.shape(input_data2) - verify_model(reshape1, [input_data, new_shape], input_shape=[[2, 1, 10, 1, 10], [4]]) + verify_model(reshape1, input_data=[input_data, input_data2]) verify_model(reshape2, input_data=input_data) verify_model(reshape3, input_data=paddle.randn((2, 3, 4))) verify_model(reshape4, input_data=[input_data, input_data2]) @@ -943,8 +745,8 @@ def scale2(inputs): @tvm.testing.uses_gpu def test_forward_slice(): @paddle.jit.to_static - def slice1(inputs, end): - return inputs[:, :, :, :end] + def slice1(inputs): + return inputs[:, :, :, :3] @paddle.jit.to_static def slice2(inputs): @@ -960,53 +762,45 @@ def slice4(inputs): x1 = paddle.to_tensor([3]) + paddle.to_tensor([1]) return inputs[:, x0:, 1:x1, :] - @paddle.jit.to_static - def slice5(inputs): - x0 = paddle.to_tensor([3]) - return inputs[:, 1::1, 2::x0, 
4:10] - input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - end = paddle.to_tensor(np.array([3])) - verify_model(slice1, [input_data, end], input_shape=[[1, 3, 10, 10], [1]]) + verify_model( + slice1, + input_data=[ + input_data, + ], + ) verify_model(slice2, input_data=input_data) - verify_model(slice3, input_data=paddle.randn((4, 4))) - verify_model(slice4, input_data=input_data) - verify_model(slice5, input_data=input_data) + # need op "strided_slice" + # verify_model(slice3, input_data=paddle.randn((4, 4))) + # need op "assign_value" + # verify_model(slice4, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_math_api(): + class MathAPI(nn.Layer): + def __init__(self, api_name): + super(MathAPI, self).__init__() + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, api_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, inputs): + return self.func(inputs) + + api_list = [ + "exp", + "relu", + "tanh", + ] + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + for api_name in api_list: + verify_model(MathAPI(api_name), input_data=input_data) if __name__ == "__main__": - test_forward_add_subtract() - test_forward_argmax() - test_forward_argmin() - test_forward_argsort() - test_forward_assign() - test_forward_batch_norm() - test_forward_cast() - test_forward_concat_unsqueeze() - test_forward_conv() - test_forward_cumsum() - test_forward_dot() - test_forward_dropout() - test_forward_elemwise() - test_forward_expand() - test_forward_expand_as() - test_forward_shape_full() - test_forward_ones() - test_forward_ones_like() - test_forward_gelu() - test_forward_math() - test_forward_activation() - test_forward_check_tensor() - test_forward_layer_norm() - test_forward_logical_op() - test_forward_look_up() - test_forward_matmul() - test_forward_meshgrid() - test_forward_mm() - test_forward_multiply() - test_forward_pool2d() - test_forward_rank() - test_forward_reshape() - test_forward_scale() - test_forward_slice() + pytest.main([__file__])
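
For reference, the converter entry point modified above is exercised by verify_model in test_forward.py with the following pattern: save and reload the model through paddle.jit, call relay.frontend.from_paddle with a dict of fixed input shapes, then build and run via the graph executor. Below is a minimal sketch of that flow; the model path "./model", the input name "input0", and the [1, 3, 10, 10] shape are illustrative placeholders, not values taken from this patch.

import numpy as np
import paddle
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Load a TranslatedLayer previously saved with paddle.jit.save (the path is a placeholder).
model = paddle.jit.load("./model")

# Fixed input shapes are recommended by the importer; keys follow the model's InputSpec names.
shape_dict = {"input0": [1, 3, 10, 10]}
mod, params = relay.frontend.from_paddle(model, shape_dict)

# Build and run through the graph executor, as verify_model does in the test file.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
gmod = graph_executor.GraphModule(lib["default"](tvm.cpu()))
gmod.set_input("input0", np.random.rand(1, 3, 10, 10).astype("float32"))
gmod.run()
tvm_output = gmod.get_output(0).numpy()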