diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 7470ccc92496..d6dd094f6a5b 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -115,6 +115,10 @@ ], ), ), + ( + "importer-paddle", + ("Requirements for the PaddlePaddle importer", ["paddlepaddle"]), + ), ( "importer-pytorch", ( @@ -235,6 +239,7 @@ ("onnx", None), ("onnxruntime", None), ("opencv-python", None), + ("paddlepaddle", None), ("pillow", None), ("progressbar", None), ("psutil", None), diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 76a12691d2bf..e84a259c73f9 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -25,12 +25,15 @@ from tvm.ir import IRModule from .. import analysis +from .. import ty as _ty from .. import expr as _expr +from ..loops import while_loop from .. import function as _function from .. import ty as _ty from .. import op as _op from .common import ( fold_constant, + get_relay_op, infer_shape, infer_type, infer_value, @@ -40,28 +43,122 @@ __all__ = ["from_paddle"] +def _get_pad_size(in_size, dilated_kernel_size, stride_size): + """calculate the paddings size""" + + if stride_size == 1 or in_size % stride_size == 0: + pad = max(dilated_kernel_size - stride_size, 0) + else: + pad = max(dilated_kernel_size - (in_size % stride_size), 0) + + pad_before = pad // 2 + pad_after = pad - pad_before + + return [pad_before, pad_after] + + +def _dtype_shape_promotion(inputs): + """promote data type and shape for list of tensors.""" + + dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"] + + ranks = [len(infer_shape(x)) for x in inputs] + if set(ranks) == set([1, 0]): + for i, r in enumerate(ranks): + if r == 0: + inputs[i] = _op.expand_dims(inputs[i], axis=0) + + dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs) + if len(dtypes) == 1: + return inputs + max_dtype = dtype_order[max(dtypes)] + for i, input_op in enumerate(inputs): + if infer_type(input_op).checked_type.dtype != max_dtype: + inputs[i] = input_op.astype(max_dtype) + return inputs + + def shape_of(x, dtype="int32"): """Get shape of a tensor""" ttype = infer_type(x).checked_type if not _ty.is_dynamic(ttype): shape = list(ttype.shape) - return _expr.const(shape, dtype) + return _expr.const(np.array(shape), dtype) return _op.shape_of(x, dtype) -def _get_pad_size(in_size, dilated_kernel_size, stride_size): - """calculate the paddings size""" +def _infer_value(x, params): + """Try running infer_value, and if successful, return the inferred value. + Otherwise, return input""" - if stride_size == 1 or in_size % stride_size == 0: - pad = max(dilated_kernel_size - stride_size, 0) + try: + value = infer_value(x, params) + return value.numpy().tolist() + except Exception: # pylint: disable=broad-except + return x + + +def _convert_dtype_value(val): + """converts a Paddle type id to a string.""" + + convert_dtype_map = { + 21: "int8", + 20: "uint8", + 6: "float64", + 5: "float32", + 4: "float16", + 3: "int64", + 2: "int32", + 1: "int16", + 0: "bool", + } + if val not in convert_dtype_map: + msg = "Paddle data type value %d is not handled yet." 
% (val) + raise NotImplementedError(msg) + return convert_dtype_map[val] + + +def convert_unary_op(g, op, block): + """Operator converter for all the activation.""" + + op_map = { + "isinf_v2": _op.isinf, + "isfinite_v2": _op.isfinite, + "isnan_v2": _op.isnan, + } + if op.type in op_map: + unary_func = op_map[op.type] else: - pad = max(dilated_kernel_size - (in_size % stride_size), 0) + unary_func = get_relay_op(op.type) + out = unary_func(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) - pad_before = pad // 2 - pad_after = pad - pad_before - return [pad_before, pad_after] +def convert_addmm(g, op, block): + """Operator converter for addmm.""" + + input_x = g.get_node(op.input("Input")[0]) + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + + alpha = op.attr("Alpha") + beta = op.attr("Beta") + dtype = block.var(op.output("Out")[0]).dtype + dtype = str(dtype).strip().split(".")[1] + + if not isinstance(alpha, _expr.Expr) and alpha != 1: + alpha = _expr.const(alpha, dtype) + x *= alpha + + if not isinstance(beta, _expr.Expr) and beta != 1: + beta = _expr.const(beta, dtype) + input_x *= beta + + transposed_y = _op.transpose(y, axes=[1, 0]) + dense_out = _op.nn.dense(x, transposed_y) + out = dense_out + input_x + g.add_node(op.output("Out")[0], out) def convert_arg_max(g, op, block): @@ -70,6 +167,8 @@ def convert_arg_max(g, op, block): axis = op.attr("axis") keepdims = op.attr("keepdims") flatten = op.attr("flatten") + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) if axis is None or flatten: @@ -77,13 +176,64 @@ def convert_arg_max(g, op, block): out = _op.argmax(x, axis=None, keepdims=True) else: out = _op.argmax(x, axis=axis, keepdims=keepdims) + if dtype != infer_type(out).checked_type.dtype: + out = _op.cast(out, dtype) g.add_node(op.output("Out")[0], out) +def convert_arg_min(g, op, block): + """Operator converter for arg_min.""" + + axis = op.attr("axis") + keepdims = op.attr("keepdims") + flatten = op.attr("flatten") + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) + + x = g.get_node(op.input("X")[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + out = _op.argmin(x, axis=None, keepdims=True) + else: + out = _op.argmin(x, axis=axis, keepdims=keepdims) + if dtype != infer_type(out).checked_type.dtype: + out = _op.cast(out, dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_argsort(g, op, block): + """Operator converter for argsort.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + descending = op.attr("descending") + + out = _op.sort(x, axis, not descending) + out_indice = _op.argsort(x, axis, not descending, dtype="int64") + g.add_node(op.output("Out")[0], out) + g.add_node(op.output("Indices")[0], out_indice) + + def convert_assign(g, op, block): """Operator converter for assign.""" - out = _op.copy(g.get_node(op.input("X")[0])) + out = g.get_node(op.input("X")[0]) + g.add_node(op.output("Out")[0], out) + + +def convert_assign_value(g, op, block): + """Operator converter for assign_value.""" + + keys = ["bool_values", "fp32_values", "int32_values", "int64_values"] + dtypes = ["bool", "float32", "int32", "int64"] + for i, key in enumerate(keys): + dtype = dtypes[i] + value = np.array(op.attr(key)).astype(dtype) + if value is not None and value.size >= 1: + break + shape = op.attr("shape") + value = value.reshape(shape) + out = _op.const(value, dtype=dtype) g.add_node(op.output("Out")[0], out) @@ -107,21 +257,71 @@ def convert_batch_norm(g, 
op, block): g.add_node(op.output("Y")[0], out[0]) +def convert_bmm(g, op, block): + """Operator converter for bmm.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + y = _op.transpose(y, [0, 2, 1]) + out = _op.nn.batch_matmul(x, y) + g.add_node(op.output("Out")[0], out) + + def convert_cast(g, op, block): """Operator converter for cast.""" - dtype = block.var(op.output("Out")[0]).dtype - dtype = str(dtype).strip().split(".")[1] + dtype = op.attr("out_dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) out = _op.cast(x, dtype=dtype) g.add_node(op.output("Out")[0], out) +def convert_clip(g, op, block): + """Operator converter for clip.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + is_dynamic = False + if op.input("Min"): + min_value = g.get_node(op.input("Min")[0]) + min_value = _infer_value(min_value, g.get_params()) + if isinstance(min_value, _expr.Expr): + is_dynamic = True + else: + min_value = min_value[0] + else: + min_value = op.attr("min") + if op.input("Max"): + max_value = g.get_node(op.input("Max")[0]) + max_value = _infer_value(max_value, g.get_params()) + if isinstance(max_value, _expr.Expr): + if not is_dynamic: + is_dynamic = True + min_value = _op.const(min_value, dtype) + else: + max_value = max_value[0] + if is_dynamic: + max_value = _op.const(max_value, dtype) + else: + max_value = op.attr("max") + if is_dynamic: + max_value = _op.const(max_value, dtype) + + if not is_dynamic: + out = _op.clip(x, min_value, max_value) + else: + out = _op.maximum(x, min_value) + out = _op.minimum(out, max_value) + g.add_node(op.output("Out")[0], out) + + def convert_concat(g, op, block): """Operator converter for concat.""" inputs = [g.get_node(op.input("X")[i]) for i in range(len(op.input("X")))] axis = op.attr("axis") + inputs = _dtype_shape_promotion(inputs) out = _op.concatenate(inputs, axis=axis) g.add_node(op.output("Out")[0], out) @@ -138,12 +338,22 @@ def convert_conv2d(g, op, block): kernel = g.get_node(op.input("Filter")[0]) input_x = g.get_node(op.input("Input")[0]) out_channels, _, k_h, k_w = infer_shape(kernel) - in_h, in_w = infer_shape(input_x)[2:] if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) + if strides[0] == 1 and strides[1] == 1: + pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) + else: + input_shape = shape_of(input_x) + h_w = _op.strided_slice(input_shape, [2], [4]) + try: + in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() + except Exception as e: + msg = "The SAME padding algorithm of Conv not support dynamic shape" + raise tvm.error.OpAttributeInvalid(msg) from e + pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: @@ -167,6 +377,90 @@ def convert_conv2d(g, op, block): g.add_node(op.output("Output")[0], out) +def convert_conv2d_transpose(g, op, block): + """Operator converter for conv2d_transpose.""" + + dilations = op.attr("dilations") + groups = op.attr("groups") + paddings = op.attr("paddings") + padding_algorithm = op.attr("padding_algorithm") + strides = op.attr("strides") + output_padding = 
op.attr("output_padding") if op.attr("output_padding") else [0, 0] + + kernel = g.get_node(op.input("Filter")[0]) + input_x = g.get_node(op.input("Input")[0]) + _, out_channels, k_h, k_w = infer_shape(kernel) + if padding_algorithm == "VALID": + paddings = [0, 0] + elif padding_algorithm == "SAME": + if strides[0] == 1 and strides[1] == 1: + pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) + else: + input_shape = shape_of(input_x) + h_w = _op.strided_slice(input_shape, [2], [4]) + try: + in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() + except Exception as e: + msg = "The SAME padding algorithm of Conv_Transpose not support dynamic shape" + raise tvm.error.OpAttributeInvalid(msg) from e + pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) + paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + elif padding_algorithm == "EXPLICIT": + if len(paddings) == 2: + paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] + if len(paddings) == 4: + paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] + else: + msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + + out = _op.nn.conv2d_transpose( + input_x, + kernel, + strides=strides, + padding=paddings, + dilation=dilations, + groups=groups, + channels=out_channels, + kernel_size=[k_h, k_w], + output_padding=output_padding, + ) + g.add_node(op.output("Output")[0], out) + + +def convert_crop(g, op, block): + """Operator converter for crop.""" + + x = g.get_node(op.input("X")[0]) + dims = len(infer_shape(x)) + input_shape = op.input("Shape") + input_offsets = op.input("Offsets") + if input_shape: + shape = g.get_node(input_shape[0]) + shape = _infer_value(shape, g.get_params()) + else: + shape = op.attr("shape") + + if input_offsets: + offsets = g.get_node(input_offsets[0]) + offsets = _infer_value(offsets, g.get_params()) + else: + offsets = op.attr("offsets") + + if not isinstance(shape, _expr.Expr): + shape = _op.const(shape, "int32") + if not isinstance(offsets, _expr.Expr): + offsets = _op.const(offsets, "int32") + slice_start = offsets + slice_end = _op.add(shape, offsets) + strides = _op.const([1] * dims, dtype="int32") + + out = _op.strided_slice(x, slice_start, slice_end, strides) + g.add_node(op.output("Out")[0], out) + + def convert_cumsum(g, op, block): """Operator converter for cumsum.""" @@ -191,7 +485,51 @@ def convert_dropout(g, op, block): """Operator converter for dropout.""" x = g.get_node(op.input("X")[0]) - out = _op.copy(x) + g.add_node(op.output("Out")[0], x) + + +def convert_elu(g, op, block): + """Operator converter for elu.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + alpha = op.attr("alpha") + alpha = _expr.const(-1.0 * alpha, dtype=dtype) + out = alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(x)) + _op.nn.relu(x) + g.add_node(op.output("Out")[0], out) + + +def convert_dist(g, op, block): + """Operator converter for dist.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + dtype = infer_type(x).checked_type.dtype + p = op.attr("p") + + x -= y + if p == np.inf: + out = _op.reduce.max(_op.abs(x)) + elif p == np.NINF: + out = _op.reduce.min(_op.abs(x)) + else: + reci_order = _expr.const(1.0 / p, dtype=dtype) + p = _expr.const(p) + out = _op.power( + 
_op.reduce.sum(_op.power(_op.abs(x), p)), + reci_order, + ) + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_dot(g, op, block): + """Operator converter for dot.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + + out = _op.sum(_op.multiply(x, y), axis=[-1], keepdims=True) g.add_node(op.output("Out")[0], out) @@ -199,49 +537,59 @@ def convert_elementwise_op(g, op, block): """Operator converter for all the elementwise operators.""" op_map = { - "elementwise_div": lambda x, y: x / y, - "elementwise_add": lambda x, y: x + y, - "elementwise_mul": lambda x, y: x * y, - "elementwise_sub": lambda x, y: x - y, - "elementwise_mod": lambda x, y: x % y, + "elementwise_div": "divide", + "elementwise_add": "add", + "elementwise_mul": "multiply", + "elementwise_sub": "subtract", + "elementwise_mod": "mod", + "elementwise_max": "maximum", + "elementwise_min": "minimum", + "elementwise_pow": "power", + "elementwise_floordiv": "floor_divide", + "floor_mod": "floor_mod", + "equal": "equal", + "greater_equal": "greater_equal", + "greater_than": "greater", + "less_equal": "less_equal", + "less_than": "less", + "not_equal": "not_equal", } op_func = op_map[op.type] ipt0 = g.get_node(op.input("X")[0]) ipt1 = g.get_node(op.input("Y")[0]) - ipt0_shape = block.var(op.input("X")[0]).shape - ipt1_shape = block.var(op.input("Y")[0]).shape + ipt0_shape = infer_shape(ipt0) + ipt1_shape = infer_shape(ipt1) axis = op.attr("axis") if len(ipt0_shape) != len(ipt1_shape): if axis < 0: axis = axis + len(ipt0_shape) if axis != len(ipt0_shape) - 1: ipt1 = _op.expand_dims(ipt1, axis=axis, num_newaxis=(len(ipt0_shape) - axis - 1)) + op_func = get_relay_op(op_func) out = op_func(ipt0, ipt1) g.add_node(op.output("Out")[0], out) -def convert_equal(g, op, block): - """Operator converter for equal.""" +def convert_expand(g, op, block): + """Operator converter for expand.""" x = g.get_node(op.input("X")[0]) - y = g.get_node(op.input("Y")[0]) - out = _op.equal(x, y) + if op.input("Shape"): + sizes = g.get_node(op.input("Shape")[0]) + sizes = _infer_value(sizes, g.get_params()) + else: + sizes = op.attr("shape") + + out = _op.broadcast_to(x, sizes) g.add_node(op.output("Out")[0], out) -def convert_activation(g, op, block): - """Operator converter for all the activation.""" +def convert_expand_as(g, op, block): + """Operator converter for expand_as.""" - op_map = { - "exp": _op.exp, - "relu": _op.nn.relu, - "tanh": _op.tanh, - "sqrt": _op.sqrt, - "erf": _op.erf, - "abs": _op.abs, - } - act_func = op_map[op.type] - out = act_func(g.get_node(op.input("X")[0])) + x = g.get_node(op.input("X")[0]) + target_shape = op.attr("target_shape") + out = _op.broadcast_to(x, target_shape) g.add_node(op.output("Out")[0], out) @@ -259,43 +607,32 @@ def convert_feed(g, op, block): ipt_name = op.name if g.shape_dict is not None: ipt_shape = g.shape_dict[ipt_name] + + if isinstance(ipt_shape, tuple): + ipt_shape = list(ipt_shape) + for i, s in enumerate(ipt_shape): + if s < 0: + ipt_shape[i] = _ty.Any() out = new_var(ipt_name, shape=ipt_shape, dtype=ipt_dtype) g.add_node(ipt_name, out) -def convert_fill_any_like(g, op, block): - """Operator converter for fill_any_like.""" - - out_name = op.output("Out")[0] - out_dtype = block.var(out_name).dtype - out_dtype = str(out_dtype).strip().split(".")[1] - x = g.get_node(op.input("X")[0]) - ipt_type = infer_type(x).checked_type - value = op.attr("value") - if not _ty.is_dynamic(ipt_type): - shape = infer_shape(x) - const = np.ones(shape) * value - 
out = _expr.const(const.astype(out_dtype)) - else: - out = _op.transform.full_like(x, value).astype(out_dtype) - g.add_node(op.output("Out")[0], out) - - def convert_fill_constant(g, op, block): """Operator converter for fill_constant.""" value = op.attr("value") shape = block.var(op.output("Out")[0]).shape - dtype = block.var(op.output("Out")[0]).dtype - dtype = str(dtype).strip().split(".")[1] - if op.input("ValueTensor"): + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) + value = _expr.const(value).astype(dtype) + if "ValueTensor" in op.input_names and op.input("ValueTensor"): shape = g.get_node(op.input("ValueTensor")[0]) - shape = infer_value(shape, g.get_params()).numpy() - if op.input("ShapeTensor"): + shape = _infer_value(shape, g.get_params()) + if "ShapeTensor" in op.input_names and op.input("ShapeTensor"): shape = g.get_node(op.input("ShapeTensor")[0]) - shape = infer_value(shape, g.get_params()).numpy() - value = np.full(shape, value, dtype) - out = _expr.const(value.astype(dtype)).astype(dtype) + shape = _infer_value(shape, g.get_params()) + + out = _op.full(value, shape=shape, dtype=dtype) g.add_node(op.output("Out")[0], out) @@ -310,12 +647,25 @@ def convert_gelu(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_hard_shrink(g, op, block): + """Operator converter for hard_shrink.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + threshold = op.attr("threshold") + threshold = _op.const(threshold, dtype) + out = _op.logical_or(x < _op.const(-1.0, dtype) * threshold, x > threshold) + out = _op.cast(out, dtype) * x + g.add_node(op.output("Out")[0], out) + + def convert_hard_sigmoid(g, op, block): """Operator converter for hard_sigmoid.""" slope = op.attr("slope") x = g.get_node(op.input("X")[0]) - out = x * _expr.const(slope) + _expr.const(0.5) + dtype = infer_type(x).checked_type.dtype + out = x * _expr.const(slope, dtype) + _expr.const(0.5, dtype) out = _op.clip(out, 0, 1) g.add_node(op.output("Out")[0], out) @@ -330,12 +680,23 @@ def convert_hard_swish(g, op, block): assert np.isclose(scale, 6.0), "Only support scale==6.0 for PaddlePaddle's hard_swish" assert np.isclose(threshold, 6.0), "Only support threshold==6.0 for PaddlePaddle's hard_swish" x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype out = _op.clip(x, -1 * offset, offset) - out = out / _expr.const(threshold) + _expr.const(0.5) + out = out / _expr.const(threshold, dtype) + _expr.const(0.5, dtype) out = x * out g.add_node(op.output("Out")[0], out) +def convert_hard_tanh(g, op, block): + """Operator converter for hard_tanh.""" + + x = g.get_node(op.input("X")[0]) + t_max = op.attr("t_max") + t_min = op.attr("t_min") + out = _op.tensor.clip(x, t_min, t_max) + g.add_node(op.output("Out")[0], out) + + def convert_layer_norm(g, op, block): """Operator converter for layer_norm.""" @@ -376,16 +737,55 @@ def convert_leaky_relu(g, op, block): g.add_node(op.output("Out")[0], out) -def convert_lookup_table(g, op, block): - """Operator converter for lookup_table_v2.""" +def convert_log1p(g, op, block): + """Operator converter for log1p.""" - indices = g.get_node(op.input("Ids")[0]) - padding_idx = op.attr("padding_idx") - if padding_idx != -1: - g.get_params[op.input("W")[0]][padding_idx] = 0.0 - g.add_node(op.input("W")[0], _expr.const(g.params[op.input("W")[0]])) - weights = g.get_node(op.input("W")[0]) - out = _op.take(weights, indices.astype("int32"), axis=0) + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype 
+ one = _expr.const(1, dtype=dtype) + out = _op.log(x + one) + g.add_node(op.output("Out")[0], out) + + +def convert_logical_op(g, op, block): + """Operator converter for logical op.""" + + ipt0 = g.get_node(op.input("X")[0]) + ipt1 = g.get_node(op.input("Y")[0]) + op_func = get_relay_op(op.type) + out = op_func(ipt0, ipt1) + g.add_node(op.output("Out")[0], out) + + +def convert_logical_not(g, op, block): + """Operator converter for logical_not op.""" + + ipt0 = g.get_node(op.input("X")[0]) + op_func = get_relay_op(op.type) + out = op_func(ipt0) + g.add_node(op.output("Out")[0], out) + + +def convert_logsigmoid(g, op, block): + """Operator converter for logsigmoid.""" + + x = g.get_node(op.input("X")[0]) + out = _op.log(_op.tensor.sigmoid(x)) + g.add_node(op.output("Out")[0], out) + + +def convert_logsoftmax(g, op, block): + """Operator converter for logsoftmax.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + ndim = len(infer_shape(x)) + if axis < 0: + axis += ndim + m = _op.max(x, [axis], keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, [axis], keepdims=True) + out = x - m - _op.log(s) g.add_node(op.output("Out")[0], out) @@ -498,6 +898,16 @@ def flatten_to_nd(x, x_shape, nd=3): g.add_node(op.output("Out")[0], out) +def convert_meshgrid(g, op, block): + """Operator converter for meshgrid.""" + + inputs = op.input("X") + x = [g.get_node(i) for i in inputs] + outs = _op.meshgrid(x, indexing="ij") + for i, out in enumerate(outs): + g.add_node(op.output("Out")[i], out) + + def convert_mul(g, op, block): """Operator converter for mul.""" @@ -505,8 +915,8 @@ def convert_mul(g, op, block): y = g.get_node(op.input("Y")[0]) x_num_col_dims = op.attr("x_num_col_dims") y_num_col_dims = op.attr("y_num_col_dims") - x_shape = shape_of(x) - y_shape = shape_of(y) + x_shape = _op.shape_of(x) + y_shape = _op.shape_of(y) x_dim = infer_shape(x_shape)[0] y_dim = infer_shape(y_shape)[0] if x_num_col_dims < 0: @@ -543,6 +953,37 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_mv(g, op, block): + """Operator converter for mv.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Vec")[0]) + y = _op.expand_dims(y, axis=-1) + y = _op.transpose(y) + out = _op.nn.dense(x, y) + out = _op.squeeze(out, axis=[-1]) + g.add_node(op.output("Out")[0], out) + + +def convert_numel(g, op, block): + """Operator converter for numel.""" + + input_x = g.get_node(op.input("Input")[0]) + out = _op.ndarray_size(input_x, dtype="int64") + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_nonzero(g, op, block): + """Operator converter for nonzero.""" + + input_x = g.get_node(op.input("Condition")[0]) + out = _op.transform.argwhere(input_x) + # Paddle NonZero always outputs int64 + out = _op.cast(out, "int64") + g.add_node(op.output("Out")[0], out) + + def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -558,7 +999,7 @@ def convert_pool2d(g, op, block): ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - in_h, in_w = infer_shape(input_x)[2:] + _, _, in_h, in_w = infer_shape(input_x) op_map = { "avg": "avg_pool2d", @@ -575,8 +1016,19 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) + if strides[0] == 1 and strides[1] == 1: + pad_h = _get_pad_size(0, ksize[0], strides[0]) + pad_w = _get_pad_size(0, ksize[1], strides[1]) + else: + 
input_shape = shape_of(input_x) + h_w = _op.strided_slice(input_shape, [2], [4]) + try: + in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() + except Exception as e: + msg = "The SAME padding algorithm of Conv not support dynamic shape" + raise tvm.error.OpAttributeInvalid(msg) from e + pad_h = _get_pad_size(in_h, ksize[0], strides[0]) + pad_w = _get_pad_size(in_w, ksize[1], strides[1]) paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: @@ -587,6 +1039,11 @@ def convert_pool2d(g, op, block): msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + ksize[0] = in_h + if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + ksize[1] = in_w + if not adaptive: out = getattr(_op.nn, op_map[pooling_type])( input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode @@ -605,18 +1062,14 @@ def convert_reshape(g, op, block): if input_shape: new_shape = g.get_node(input_shape[0]) elif input_shape_tensor: - tmp_shape = [] + new_shape = [] for shape_name in input_shape_tensor: shape = g.get_node(shape_name) if len(infer_shape(shape)) == 0: shape = _op.reshape(shape, [-1]) - if isinstance(shape, _expr.Constant): - tmp_shape.append(shape) - elif isinstance(shape, _expr.Expr): - tmp_shape.append(shape) - else: - tmp_shape.append(_expr.const(np.array(shape).astype("int64"))) - new_shape = _op.concatenate(tmp_shape, axis=0) + new_shape.append(shape.astype("int64")) + new_shape = _op.concatenate(new_shape, axis=0) + new_shape = _infer_value(new_shape, g.get_params()) else: new_shape = op.attr("shape") out = _op.reshape(data, new_shape) @@ -631,8 +1084,11 @@ def convert_scale(g, op, block): bias_after_scale = op.attr("bias_after_scale") x = g.get_node(op.input("X")[0]) if np.isclose(scale, 1.0) and np.isclose(bias, 0.0): - out = _op.copy(x) + out = x else: + x_dtype = infer_type(x).checked_type.dtype + if x_dtype != "float32": + x = x.astype("float32") if np.isclose(bias, 0.0): out = x * _expr.const(np.array(scale).astype("float32")) elif np.isclose(scale, 1.0): @@ -646,6 +1102,58 @@ def convert_scale(g, op, block): out = (x + _expr.const(np.array(bias).astype("float32"))) * _expr.const( np.array(scale).astype("float32") ) + if x_dtype != "float32": + out = out.astype(x_dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_scatter(g, op, block): + """Operator converter for scatter.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Ids")[0]) + updates = g.get_node(op.input("Updates")[0]) + overwrite = op.attr("overwrite") + + shape = infer_shape(updates) + ndims = len(shape) + index = _op.expand_dims(index, axis=-1, num_newaxis=ndims - 1) + index = _op.transform.broadcast_to(index, shape) + + if overwrite: + out = _op.scatter(x, index, updates, axis=0) + else: + out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0) + out += _op.scatter(x, index, _op.zeros_like(updates), axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_scatter_nd_add(g, op, block): + """Operator converter for scatter_nd_add.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + updates = g.get_node(op.input("Updates")[0]) + indices_dim = len(infer_shape(index)) + axes = list(range(indices_dim)) + index = _op.transpose(index, axes[-1:] + axes[:-1]) + out = _op.scatter_nd(x, index, updates, mode="add") + 
g.add_node(op.output("Out")[0], out) + + +def convert_selu(g, op, block): + """Operator converter for selu.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(op.attr("alpha"), dtype) + scale = _op.const(op.attr("scale"), dtype) + out = ( + _expr.const(-1.0, dtype=dtype) + * alpha + * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(x)) + ) + out = scale * (out + _op.nn.relu(x)) g.add_node(op.output("Out")[0], out) @@ -660,38 +1168,107 @@ def convert_shape(g, op, block): def convert_slice(g, op, block): """Operator converter for slice.""" - def parameter_process(starts, ends, axes, dshape): - new_axes = [] - new_starts = [] - new_ends = [] - pop_index = 0 - for i in range(max(axes) + 1): - new_axes.append(i) - if i in axes: - new_starts.append(starts[pop_index]) - new_ends.append(ends[pop_index]) - pop_index += 1 - else: - new_starts.append(0) - new_ends.append(dshape[i]) - return new_starts, new_ends, new_axes - data = g.get_node(op.input("Input")[0]) - dshape = infer_shape(data) - starts = op.attr("starts") - ends = op.attr("ends") + dims = len(infer_shape(data)) + axes = op.attr("axes") + indices = _expr.const(axes, dtype="int64") + decrease_axis = op.attr("decrease_axis") - if isinstance(starts, int): - starts = [starts] - if isinstance(ends, int): - ends = [ends] - if isinstance(axes, int): - axes = [axes] if isinstance(decrease_axis, int): decrease_axis = [decrease_axis] - starts, ends, axes = parameter_process(starts, ends, axes, dshape) - out = _op.strided_slice(data, begin=starts, end=ends) + + if op.input("StartsTensor"): + starts = g.get_node(op.input("StartsTensor")[0]) + starts = _infer_value(starts, g.get_params()) + elif op.input("StartsTensorList"): + starts = [] + for start_index in op.input("StartsTensorList"): + start_index = g.get_node(start_index).astype("int64") + starts.append(start_index) + starts = _op.concatenate(starts, axis=0) + starts = _infer_value(starts, g.get_params()) + else: + starts = op.attr("starts") + + if len(axes) < dims: + if isinstance(starts, _expr.Expr): + starts = _op.scatter( + _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype), + indices, + starts, + axis=0, + ) + else: + base = [0] * dims + for i, axis in enumerate(axes): + base[axis] = starts[i] + starts = base + + if op.input("EndsTensor"): + ends = g.get_node(op.input("EndsTensor")[0]) + ends = _infer_value(ends, g.get_params()) + elif op.input("EndsTensorList"): + ends = [] + for end_index in op.input("EndsTensorList"): + end_index = g.get_node(end_index).astype("int64") + ends.append(end_index) + ends = _op.concatenate(ends, axis=0) + ends = _infer_value(ends, g.get_params()) + else: + ends = op.attr("ends") + + if len(axes) < dims: + if isinstance(ends, _expr.Expr): + ends = _op.scatter( + _expr.const( + np.array([np.iinfo(np.int32).max] * dims), + dtype=infer_type(ends).checked_type.dtype, + ), + indices, + ends, + axis=0, + ) + else: + base = [np.iinfo(np.int32).max] * dims + for i, axis in enumerate(axes): + base[axis] = ends[i] + ends = base + + strides = None + if "StridesTensor" in op.input_names and op.input("StridesTensor"): + strides = g.get_node(op.input("StridesTensor")[0]) + strides = _infer_value(strides, g.get_params()) + elif "StridesTensorList" in op.input_names and op.input("StridesTensorList"): + strides = [] + for strides_index in op.input("StridesTensorList"): + strides_index = g.get_node(strides_index).astype("int64") + strides.append(strides_index) + strides = _op.concatenate(strides, axis=0) + 
strides = _infer_value(strides, g.get_params()) + elif op.has_attr("strides"): + strides = op.attr("strides") + + if len(axes) < dims: + if isinstance(strides, _expr.Expr): + strides = _op.scatter( + _expr.const( + np.array([1] * dims), + dtype=infer_type(strides).checked_type.dtype, + ), + indices, + strides, + axis=0, + ) + elif strides: + base = [1] * dims + for i, axis in enumerate(axes): + base[axis] = strides[i] + strides = base + if not strides: + strides = _op.const([1] * dims, dtype="int64") + + out = _op.strided_slice(data, begin=starts, end=ends, strides=strides) if decrease_axis: out = _op.squeeze(out, axis=decrease_axis) g.add_node(op.output("Out")[0], out) @@ -722,41 +1299,105 @@ def convert_unsqueeze(g, op, block): _convert_map = { + "abs": convert_unary_op, + "acos": convert_unary_op, + "addmm": convert_addmm, "arg_max": convert_arg_max, + "arg_min": convert_arg_min, + "argsort": convert_argsort, + "asin": convert_unary_op, "assign": convert_assign, + "assign_value": convert_assign_value, + "atan": convert_unary_op, "batch_norm": convert_batch_norm, + "bmm": convert_bmm, + "brelu": convert_hard_tanh, "cast": convert_cast, + "ceil": convert_unary_op, + "clip": convert_clip, "concat": convert_concat, "conv2d": convert_conv2d, + "conv2d_transpose": convert_conv2d_transpose, + "cos": convert_unary_op, + "cosh": convert_unary_op, + "crop_tensor": convert_crop, "cumsum": convert_cumsum, "depthwise_conv2d": convert_conv2d, + "dist": convert_dist, + "dot": convert_dot, "dropout": convert_dropout, "elementwise_add": convert_elementwise_op, "elementwise_div": convert_elementwise_op, "elementwise_mul": convert_elementwise_op, "elementwise_sub": convert_elementwise_op, - "equal": convert_equal, - "exp": convert_activation, + "elementwise_mod": convert_elementwise_op, + "elementwise_max": convert_elementwise_op, + "elementwise_min": convert_elementwise_op, + "elementwise_pow": convert_elementwise_op, + "elementwise_floordiv": convert_elementwise_op, + "elu": convert_elu, + "equal": convert_elementwise_op, + "erf": convert_unary_op, + "exp": convert_unary_op, + "expand_v2": convert_expand, + "expand_as_v2": convert_expand_as, "feed": convert_feed, - "fill_any_like": convert_fill_any_like, "fill_constant": convert_fill_constant, + "floor": convert_unary_op, + "floor_mod": convert_elementwise_op, "gelu": convert_gelu, + "greater_equal": convert_elementwise_op, + "greater_than": convert_elementwise_op, + "hard_shrink": convert_hard_shrink, "hard_sigmoid": convert_hard_sigmoid, "hard_swish": convert_hard_swish, + "isfinite": convert_unary_op, + "isfinite_v2": convert_unary_op, + "isinf": convert_unary_op, + "isinf_v2": convert_unary_op, + "isnan": convert_unary_op, + "isnan_v2": convert_unary_op, "layer_norm": convert_layer_norm, "leaky_relu": convert_leaky_relu, - "lookup_table_v2": convert_lookup_table, + "less_equal": convert_elementwise_op, + "less_than": convert_elementwise_op, + "log": convert_unary_op, + "log2": convert_unary_op, + "log10": convert_unary_op, + "log1p": convert_log1p, + "logical_and": convert_logical_op, + "logical_not": convert_logical_not, + "logical_or": convert_logical_op, + "logical_xor": convert_logical_op, + "logsigmoid": convert_logsigmoid, + "log_softmax": convert_logsoftmax, "matmul": convert_matmul, "matmul_v2": convert_matmul, + "meshgrid": convert_meshgrid, + "mv": convert_mv, "mul": convert_mul, + "not_equal": convert_elementwise_op, "pool2d": convert_pool2d, - "relu": convert_activation, + "relu": convert_unary_op, "reshape2": convert_reshape, + "round": 
convert_unary_op, + "rsqrt": convert_unary_op, "scale": convert_scale, + "scatter": convert_scatter, + "scatter_nd_add": convert_scatter_nd_add, + "selu": convert_selu, "shape": convert_shape, + "sigmoid": convert_unary_op, + "sign": convert_unary_op, + "sin": convert_unary_op, + "sinh": convert_unary_op, + "size": convert_numel, "slice": convert_slice, "softmax": convert_softmax, - "tanh": convert_activation, + "sqrt": convert_unary_op, + "strided_slice": convert_slice, + "tan": convert_unary_op, + "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } @@ -764,21 +1405,24 @@ def convert_unsqueeze(g, op, block): class GraphProto: """A helper class for handling relay functions from PaddlePaddle model.""" - def __init__(self): + def __init__(self, freeze_params=False): self.nodes = {} self.params = {} self.shape_dict = None + self.freeze_params = freeze_params def get_node(self, name): """get node from graph""" - assert name in self.nodes + assert name in self.nodes, "Node: {} not found".format(name) return self.nodes[name] def add_node(self, name, node): """add a node to graph""" - - self.nodes[name] = fold_constant(node) + if self.shape_dict: + self.nodes[name] = fold_constant(node) + else: + self.nodes[name] = node def get_params(self, name=None): """get params from graph""" @@ -788,6 +1432,11 @@ def get_params(self, name=None): assert name in self.params return self.params[name] + def set_params(self, params): + """set params for graph""" + + self.params = params + def extract_parameters(self, program, scope=None): """Extract all the weights from PaddlePaddle program.""" @@ -803,7 +1452,12 @@ def extract_parameters(self, program, scope=None): self.params[name] = scope[name] else: self.params[name] = np.array(scope.var(name).get_tensor()) - self.nodes[name] = _expr.const(self.params[name]) + if self.freeze_params: + self.nodes[name] = _expr.const(self.params[name]) + else: + self.nodes[name] = _expr.var( + name, shape=self.params[name].shape, dtype=str(self.params[name].dtype) + ) def check_input_shape(self, op, block): """Check the shape information of model's inputs, fixed shape is recommended.""" @@ -839,12 +1493,13 @@ def ops_to_relay(self, program, input_specs=None): if input_specs is not None: for input_spec in input_specs: convert_feed(self, input_spec, None) - for block in program.blocks: - for op in block.ops: - if op.type == "fetch": - continue + global_block = program.blocks[0] + for op in global_block.ops: + if op.type == "fetch": + continue + else: convert_func = _convert_map[op.type] - convert_func(self, op, block) + convert_func(self, op, global_block) def from_program(self, program, shape_dict, scope): """Construct the TVM relay expression from PaddlePaddle program.""" @@ -864,12 +1519,14 @@ def from_program(self, program, shape_dict, scope): if op.type == "fetch": output_names.append(op.input("X")[0]) - outputs = [self.nodes[name] for name in output_names] + outputs = [self.get_node(name) for name in output_names] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + if self.freeze_params: + self.params = {} return mod, self.params def from_translated_layer(self, layer, shape_dict): @@ -888,25 +1545,27 @@ def from_translated_layer(self, layer, shape_dict): output_names = [x.name for x in layer._output_spec()] - outputs = [self.nodes[name] for name in output_names] + outputs = [self.get_node(name) for name in output_names] outputs = 
outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + if self.freeze_params: + self.params = {} return mod, self.params -def from_paddle(program_or_layer, shape_dict=None, scope=None): +def from_paddle(program_or_layer, shape_dict=None, scope=None, freeze_params=False): """Convert a PaddlePaddle model into an equivalent Relay Function. - - PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, - and PaddlePaddle scope stores all the weights of PaddlePaddle model. + PaddlePaddle Program/TranslatedLayer represent the computation + graph of PaddlePaddle model, and PaddlePaddle scope stores all the + weights of PaddlePaddle model. """ import paddle - g = GraphProto() + g = GraphProto(freeze_params) if isinstance(program_or_layer, paddle.jit.TranslatedLayer): # model is loaded by `paddle.jit.load` mod, params = g.from_translated_layer(program_or_layer, shape_dict) diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index db07e07f9d83..4dba1db049f1 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -14,19 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os from pathlib import Path import shutil import numpy as np + +import paddle + +import paddle.nn as nn + import tvm import tvm.testing import tvm.topi.testing from tvm import relay from tvm.contrib import graph_executor -import paddle -import paddle.nn as nn PADDLE_TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data", "paddle") PADDLE_TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) @@ -34,14 +36,16 @@ def assert_shapes_match(tru, est): if tru.shape != est.shape: - msg = "Output shapes {} and {} don't match" + msg = "Paddle Output shapes {} and TVM shapes {} don't match" raise AssertionError(msg.format(tru.shape, est.shape)) + if tru.dtype != est.dtype: + msg = "Paddle Output dtype {} and TVM dtype {} don't match" + raise AssertionError(msg.format(tru.dtype, est.dtype)) def get_paddle_model(func, input_spec): global PADDLE_TEST_DATA_ROOT_PATH model_path = Path(PADDLE_TEST_DATA_ROOT_PATH, "model") - paddle.jit.save(func, str(model_path), input_spec=input_spec) baseline_model = paddle.jit.load(str(model_path)) @@ -49,8 +53,35 @@ def get_paddle_model(func, input_spec): return baseline_model -def verify_model(func, input_data, rtol=1e-5, atol=1e-5): - if not (isinstance(input_data, (tuple, list))): +def get_tvm_output_with_vm(mod, params, target, device, input_data): + """Generic function to execute and get tvm output with vm executor""" + + ex = relay.create_executor("vm", mod=mod, device=device, target=target) + params.update(input_data) + result = ex.evaluate()(**params) + if isinstance(result, tvm.runtime.NDArray): + return [ + result.numpy(), + ] + return [r.numpy() for r in result] + + +def get_tvm_output(mod, params, target, device, input_data, compiled_names, num): + """Generic function to execute and get tvm output""" + + lib = relay.build(mod, target=target, params=params) + gmod = graph_executor.GraphModule(lib["default"](device)) + for name in compiled_names: + gmod.set_input(name, input_data[name]) + gmod.run() + outputs = [] + for i in range(num): + outputs.append(gmod.get_output(i).numpy()) + return outputs + + +def 
verify_model(func, input_data, rtol=1e-5, atol=1e-5, input_shape=None): + if not isinstance(input_data, (tuple, list)): input_data = [input_data] input_spec = [] @@ -59,11 +90,13 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): compiled_input = {} for idx, data in enumerate(input_data): input_name = "input{}".format(idx) - input_spec.append( - paddle.static.InputSpec(dtype=data.dtype, shape=data.shape, name=input_name) - ) + if input_shape: + shape = input_shape[idx] + else: + shape = data.shape + input_shape_dict[input_name] = shape + input_spec.append(paddle.static.InputSpec(dtype=data.dtype, shape=shape, name=input_name)) input_names.append(input_name) - input_shape_dict[input_name] = data.shape if isinstance(data, np.ndarray): compiled_input[input_name] = data else: @@ -81,25 +114,72 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) parms_num = min(len(input_names), len(mod["main"].params)) compiled_names = [] - for arg in mod["main"].params[:parms_num]: - assert arg.name_hint in input_names - compiled_names.append(arg.name_hint) + for arg in mod["main"].params: + assert arg.name_hint in input_names or arg.name_hint in params + if arg.name_hint in input_names: + compiled_names.append(arg.name_hint) with tvm.transform.PassContext(opt_level=3): for target, dev in tvm.testing.enabled_targets(): - lib = relay.build(mod, target=target, params=params) - gmod = graph_executor.GraphModule(lib["default"](dev)) - for name in compiled_names: - gmod.set_input(name, compiled_input[name]) - gmod.run() - - for i, baseline_output in enumerate(baseline_outputs): - compiled_output = gmod.get_output(i).numpy() - + if input_shape: + tvm_output = get_tvm_output_with_vm(mod, params, target, dev, compiled_input) + else: + tvm_output = get_tvm_output( + mod, params, target, dev, compiled_input, compiled_names, len(baseline_outputs) + ) + + for baseline_output, compiled_output in zip(baseline_outputs, tvm_output): assert_shapes_match(baseline_output, compiled_output) tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) +@tvm.testing.uses_gpu +def test_forward_math(): + class MathOp(nn.Layer): + def __init__(self, op_name): + super(MathOp, self).__init__() + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, op_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, inputs): + return self.func(inputs) + + input_data = paddle.rand([1, 2, 5, 5], dtype="float32") + op_list = [ + "abs", + "acos", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", + "exp", + "floor", + "log", + "log2", + "log10", + "log1p", + "numel", + "relu", + "round", + "rsqrt", + "sigmoid", + "sign", + "rsqrt", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + ] + for op_name in op_list: + verify_model(MathOp(op_name), input_data) + + @tvm.testing.uses_gpu def test_forward_add_subtract(): input_shape = [10] @@ -124,6 +204,21 @@ def add_subtract3(inputs1, inputs2): verify_model(add_subtract3, [input_data, input_data2]) +@tvm.testing.uses_gpu +def test_forward_addmm(): + @paddle.jit.to_static + def addmm(input, x, y, alpha=1, beta=1): + return paddle.addmm(input, x, y, alpha, beta) + + input_shape = [10, 10] + x_shape = [10, 3] + y_shape = [3, 10] + input_data = paddle.rand(input_shape, dtype="float32") + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(addmm, input_data=[input_data, x_data, 
y_data]) + + @tvm.testing.uses_gpu def test_forward_argmax(): input_shape = [1, 3, 10, 10] @@ -156,26 +251,65 @@ def forward(self, inputs): @tvm.testing.uses_gpu -def test_forward_assign(): +def test_forward_argmin(): + input_shape = [1, 3, 10, 10] + + class ArgMin(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argmin(inputs) + + class ArgMin1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=1) + + class ArgMin2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=1, keepdim=False) + + class ArgMin3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=2, keepdim=True) + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMin(), input_data=input_data) + verify_model(ArgMin1(), input_data=input_data) + verify_model(ArgMin2(), input_data=input_data) + verify_model(ArgMin3(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_argsort(): + @paddle.jit.to_static + def argsort(inputs): + return paddle.argsort(inputs) + @paddle.jit.to_static - def assign(inputs): - return paddle.assign(inputs) + def argsort2(inputs): + return paddle.argsort(inputs, axis=0, descending=True) + + input_shape = [2, 3, 5] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(argsort, input_data) + input_data2 = np.random.randint(100, size=input_shape) + verify_model(argsort2, input_data2) + + +@tvm.testing.uses_gpu +def test_forward_assign(): + class Assign(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.assign(inputs) input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") - verify_model( - assign, - [ - input_data, - ], - ) + verify_model(Assign(), [input_data]) input_data2 = np.random.randint(100, size=input_shape) - verify_model( - assign, - [ - input_data2, - ], - ) + verify_model(Assign(), [input_data2], input_shape=[[-1, -1]]) @tvm.testing.uses_gpu @@ -227,18 +361,36 @@ def cast2(inputs, dtype="int64"): input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") * 100 - verify_model( - cast1, - [ - input_data, - ], - ) - verify_model( - cast2, - [ - input_data, - ], - ) + verify_model(cast1, [input_data]) + verify_model(cast2, [input_data]) + + +@tvm.testing.uses_gpu +def test_forward_clip(): + @paddle.jit.to_static + def clip(inputs): + return paddle.clip(inputs, min=3, max=5) + + @paddle.jit.to_static + def clip2(inputs, max_value): + return paddle.clip(inputs, max=max_value) + + @paddle.jit.to_static + def clip3(inputs, min_value): + return paddle.clip(inputs, min=min_value) + + @paddle.jit.to_static + def clip4(inputs, min_value, max_value): + return paddle.clip(inputs, min=min_value, max=max_value) + + verify_model(clip, paddle.to_tensor([[1, 2], [4, 6]], dtype="int32")) + x = np.array([[1.2, 3.5], [4.5, 6.4]]) + x1 = paddle.to_tensor(x, dtype="float32") + min_value = paddle.to_tensor(np.array([2.1]), dtype="float32") + max_value = paddle.to_tensor(np.array([4.5]), dtype="float32") + verify_model(clip2, [x1, max_value]) + verify_model(clip3, [x1, min_value]) + verify_model(clip4, [x1, min_value, max_value]) @tvm.testing.uses_gpu @@ -261,40 +413,60 @@ def concat_unsqueeze2(inputs): @tvm.testing.uses_gpu -def test_forward_cumsum(): +def test_forward_crop(): + @paddle.jit.to_static + def crop1(inputs): + return paddle.crop(inputs, shape=[2, 2]) + @paddle.jit.to_static - def cusum1(inputs): - return paddle.cumsum(inputs) + def 
crop2(inputs, shape): + return paddle.crop(inputs, shape=shape, offsets=[0, 1]) @paddle.jit.to_static - def cusum2(inputs): - return paddle.cumsum(inputs, axis=0) + def crop3(inputs): + offsets = paddle.to_tensor(np.array([1, 0]).astype("int32")) + return paddle.crop(inputs, shape=[3, 3], offsets=offsets) @paddle.jit.to_static - def cusum3(inputs): - return paddle.cumsum(inputs, axis=1) + def crop4(inputs, shape, offsets): + return paddle.crop(inputs, shape=shape, offsets=offsets) + + input_shape = [10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(crop1, input_data=[input_data]) + shape = paddle.to_tensor(np.array([3, 3], "int32")) + verify_model(crop2, [input_data, shape], input_shape=[[-1, -1], [2]]) + verify_model(crop3, input_data=[input_data]) + offsets = paddle.to_tensor(np.array([1, 1]).astype("int32")) + verify_model(crop4, input_data=[input_data, shape, offsets], input_shape=[[-1, -1], [2], [2]]) + + +@tvm.testing.uses_gpu +def test_forward_cumsum(): + class Cumsum1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cumsum(inputs) + + class Cumsum2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cumsum(inputs, axis=0) + + class Cumsum3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cumsum(inputs, axis=1) input_data = paddle.randint(0, 100, (10, 10), dtype=paddle.int32) - verify_model(cusum1, [input_data]) - verify_model(cusum1, [input_data.astype(paddle.int64)]) - verify_model( - cusum2, - [ - input_data, - ], - ) - verify_model( - cusum3, - [ - input_data, - ], - ) + verify_model(Cumsum1(), input_data) + verify_model(Cumsum1(), [input_data.astype(paddle.int64)]) + verify_model(Cumsum2(), input_data) + verify_model(Cumsum3(), input_data) @tvm.testing.uses_gpu def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] - class Conv2D1(nn.Layer): def __init__(self): super(Conv2D1, self).__init__() @@ -315,9 +487,74 @@ def __init__(self): def forward(self, inputs): return self.softmax(self.conv(inputs)) + class Conv2D3(nn.Layer): + def __init__(self): + super(Conv2D3, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False, padding="SAME") + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv(inputs) + + class Conv2D4(nn.Layer): + def __init__(self): + super(Conv2D4, self).__init__() + self.conv = nn.Conv2D( + 3, 6, 7, groups=3, bias_attr=False, padding=[1, 2, 0, 1], stride=2, dilation=2 + ) + + @paddle.jit.to_static + def forward(self, inputs): + return self.conv(inputs) + + conv2d_input_shape = [1, 3, 112, 112] conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") verify_model(Conv2D1(), input_data=conv2d_input_data) verify_model(Conv2D2(), input_data=conv2d_input_data) + verify_model(Conv2D3(), input_data=conv2d_input_data) + verify_model(Conv2D4(), input_data=conv2d_input_data) + verify_model(Conv2D1(), conv2d_input_data, input_shape=[[-1, 3, 112, 112]]) + + +@tvm.testing.uses_gpu +def test_forward_dist(): + @paddle.jit.to_static + def dist(x, y): + return paddle.dist(x, y, p=2) + + @paddle.jit.to_static + def dist2(x, y): + return paddle.dist(x, y, p=20) + + @paddle.jit.to_static + def dist3(x, y): + return paddle.dist(x, y, p=float("-inf")) + + @paddle.jit.to_static + def dist4(x, y): + return paddle.dist(x, y, p=float("inf")) + + x_shape = [10, 3] + y_shape = [10, 1] + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(dist, 
input_data=[x_data, y_data]) + verify_model(dist2, input_data=[x_data, y_data]) + verify_model(dist3, input_data=[x_data, y_data]) + verify_model(dist4, input_data=[x_data, y_data]) + + +@tvm.testing.uses_gpu +def test_forward_dot(): + @paddle.jit.to_static + def dot(x, y): + return paddle.dot(x, y) + + x_shape = [10, 3] + y_shape = [10, 3] + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + verify_model(dot, input_data=[x_data, y_data]) @tvm.testing.uses_gpu @@ -333,35 +570,97 @@ def dropout(inputs): @tvm.testing.uses_gpu -def test_forward_shape_full(): +def test_forward_expand(): @paddle.jit.to_static - def full1(inputs): - return paddle.full(paddle.shape(inputs), 3.14) + def expand1(inputs): + return paddle.expand(inputs, shape=[2, 3]) @paddle.jit.to_static - def full2(inputs): - return paddle.full(paddle.shape(inputs), 1.0, dtype=inputs.dtype) + def expand2(inputs, shape): + return paddle.expand(inputs, shape=shape) - input_shape = [1, 3, 10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(full1, input_data=[input_data]) - verify_model(full2, input_data=[input_data]) + x_shape = [3] + x_data = paddle.rand(x_shape, dtype="float32") + verify_model(expand1, input_data=[x_data]) + shape = paddle.to_tensor(np.array([2, 3]).astype("int32")) + verify_model(expand2, [x_data, shape], input_shape=[[3], [2]]) + + +@tvm.testing.uses_gpu +def test_forward_expand_as(): + @paddle.jit.to_static + def expand_as(x, y): + z = paddle.expand_as(x, y) + z += y + return z + + data_x = paddle.to_tensor([1, 2, 3], dtype="int32") + data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32") + verify_model(expand_as, [data_x, data_y]) @tvm.testing.uses_gpu -def test_forward_ones_like(): +def test_forward_ones(): @paddle.jit.to_static - def ones_like1(inputs): - return paddle.ones_like(inputs) + def ones1(inputs): + ones = paddle.ones([1, 3, 10, 10]) + out = inputs + ones + return out @paddle.jit.to_static - def ones_like2(inputs): - return paddle.ones_like(inputs, dtype="int32") + def ones2(inputs): + shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32") + ones = paddle.ones(shape) + out = inputs + ones + return out input_shape = [1, 3, 10, 10] input_data = paddle.rand(input_shape, dtype="float32") - verify_model(ones_like1, input_data=input_data) - verify_model(ones_like2, input_data=input_data) + verify_model(ones1, input_data=input_data) + verify_model(ones2, input_data=input_data) + + +def test_forward_elemwise(): + class ElemwiseOp(nn.Layer): + def __init__(self, op_name): + super(ElemwiseOp, self).__init__() + self.op_name_ = op_name + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, op_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, input1, input2): + y = self.func(input1, input2) + if "equal" in self.op_name_ or "than" in self.op_name_: + y = paddle.cast(y, "int32") + return y + + op_list = [ + "floor_divide", + "floor_mod", + "maximum", + "minimum", + "equal", + "greater_equal", + "greater_than", + "less_equal", + "less_than", + "not_equal", + ] + input_shape = [10, 10] + input_shape_2 = [ + 10, + ] + x_data = paddle.rand(input_shape, dtype="float32") + y_data = paddle.rand(input_shape_2, dtype="float32") + x_data_2 = paddle.randint(1, 100, input_shape_2, dtype="int32") + y_data_2 = paddle.randint(1, 100, input_shape, dtype="int32") + for op_name in op_list: + if op_name not in ["floor_divide"]: + verify_model(ElemwiseOp(op_name), [x_data, y_data]) + 
verify_model(ElemwiseOp(op_name), [x_data_2, y_data_2])
 
 
 @tvm.testing.uses_gpu
@@ -376,25 +675,71 @@ def gelu(inputs):
 
 
 @tvm.testing.uses_gpu
-def test_forward_hard_sigmoid():
-    @paddle.jit.to_static
-    def hard_sigmoid(inputs):
-        return nn.functional.hardsigmoid(inputs)
+def test_forward_activation():
+    class Activation(nn.Layer):
+        def __init__(self, op_name):
+            super(Activation, self).__init__()
+            self.op_name_ = op_name
+            for candidate in (paddle.nn.functional, paddle):
+                self.func = getattr(candidate, op_name, None)
+                if self.func:
+                    break
+
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return self.func(inputs)
 
     input_shape = [1, 3, 10, 10]
+    input_data = paddle.normal(shape=input_shape) * 10.0
+    input_data_2 = paddle.normal(shape=input_shape).astype("float64") * 10.0
+    op_list = [
+        "elu",
+        "hardshrink",
+        "hardsigmoid",
+        "hardswish",
+        "hardtanh",
+        "log_sigmoid",
+        "log_softmax",
+        "selu",
+        "sigmoid",
+        "softsign",
+    ]
+    for op_name in op_list:
+        verify_model(Activation(op_name), input_data=input_data)
+        verify_model(Activation(op_name), input_data=input_data_2)
+
+
+@tvm.testing.uses_gpu
+def test_forward_isfinite():
+    @paddle.jit.to_static
+    def isfinite(inputs):
+        return paddle.cast(paddle.isfinite(inputs), "int32")
+
+    input_shape = [5, 5]
     input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(hard_sigmoid, input_data=input_data)
+    verify_model(isfinite, input_data=input_data)
 
 
 @tvm.testing.uses_gpu
-def test_forward_hard_swish():
+def test_forward_isinf():
     @paddle.jit.to_static
-    def hard_swish(inputs):
-        return nn.functional.hardswish(inputs)
+    def isinf(inputs):
+        return paddle.cast(paddle.isinf(inputs), "int32")
 
-    input_shape = [1, 3, 10, 10]
+    input_shape = [5, 5]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(isinf, input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_isnan():
+    @paddle.jit.to_static
+    def isnan(inputs):
+        return paddle.cast(paddle.isnan(inputs), "int32")
+
+    input_shape = [5, 5]
     input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(hard_swish, input_data=input_data)
+    verify_model(isnan, input_data=input_data)
 
 
 @tvm.testing.uses_gpu
@@ -433,25 +778,47 @@ def leaky_relu(inputs):
 
 
 @tvm.testing.uses_gpu
-def test_forward_look_up():
-    @paddle.jit.to_static
-    def look_up(inputs, weight):
-        return nn.functional.embedding(inputs, weight)
-
-    class LookUp(nn.Layer):
-        def __init__(self):
-            super(LookUp, self).__init__()
-            self.embedding = paddle.nn.Embedding(10, 4, sparse=True)
+def test_forward_logical_op():
+    class LogicalOp(nn.Layer):
+        def __init__(self, op_name, out=False):
+            super(LogicalOp, self).__init__()
+            self.out = out
+            for candidate in (paddle, paddle.nn.functional):
+                self.func = getattr(candidate, op_name, None)
+                if self.func:
+                    break
 
         @paddle.jit.to_static
-        def forward(self, inputs):
-            return self.embedding(inputs)
-
-    input_shape = [1, 3, 10, 10]
-    input_data = paddle.randint(0, 10, input_shape, dtype="int32")
-    weight = paddle.rand([10, 4], dtype="float32")
-    verify_model(look_up, input_data=[input_data, weight])
-    verify_model(LookUp(), input_data=input_data)
+        def forward(self, x, y):
+            if self.out:
+                out = paddle.to_tensor([True, True, True])
+                z = self.func(x, y, out=out)
+            else:
+                z = self.func(x, y)
+            return paddle.cast(z, "int32")
+
+    class LogicalOp_not(LogicalOp):
+        @paddle.jit.to_static
+        def forward(self, x):
+            if self.out:
+                out = paddle.to_tensor([True, True, True])
+                z = self.func(x, out=out)
+            else:
+                z = self.func(x)
+            return paddle.cast(z, "int32")
+
+    op_list = [
+        "logical_or",
+        "logical_xor",
+        "logical_and",
+    ]
+    x = paddle.to_tensor([True])
+    y = paddle.to_tensor([True, False, True, False])
+    for op_name in op_list:
+        verify_model(LogicalOp(op_name, False), [x, y])
+        verify_model(LogicalOp(op_name, True), [x, y])
+    verify_model(LogicalOp_not("logical_not", False), [y])
+    verify_model(LogicalOp_not("logical_not", True), [y])
 
 
 @tvm.testing.uses_gpu
@@ -504,6 +871,56 @@ def forward(self, input1, input2):
     verify_model(MatMul1(), input_data=[input_data1, input_data2])
 
+
+@tvm.testing.uses_gpu
+def test_forward_meshgrid():
+    @paddle.jit.to_static
+    def t(x, y, z):
+        return paddle.meshgrid(x, y, z)
+
+    x = paddle.randint(low=0, high=100, shape=[2])
+    y = paddle.randint(low=0, high=100, shape=[3])
+    z = paddle.randint(low=0, high=100, shape=[5])
+    verify_model(t, [x, y, z])
+
+
+def test_forward_mm():
+    class Mm(nn.Layer):
+        def forward(self, input1, input2):
+            return paddle.mm(input1, input2)
+
+    # matrix x vector
+    input_data1 = paddle.randn((3, 4), dtype="float32")
+    input_data2 = paddle.randn((4,), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+    # matrix x matrix
+    input_data1 = paddle.randn((5, 4), dtype="float32")
+    input_data2 = paddle.randn((4, 5), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+    # batched matrix x batched matrix
+    input_data1 = paddle.randn((10, 3, 4), dtype="float32")
+    input_data2 = paddle.randn((10, 4, 5), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+    # batched matrix x broadcasted matrix
+    input_data1 = paddle.randn((10, 3, 4), dtype="float32")
+    input_data2 = paddle.randn((4, 5), dtype="float32")
+    verify_model(Mm(), input_data=[input_data1, input_data2])
+
+
+@tvm.testing.uses_gpu
+def test_forward_mv():
+    class Mv(nn.Layer):
+        def forward(self, input1, input2):
+            return paddle.mv(input1, input2)
+
+    # matrix x vector
+    input_data1 = paddle.randn((3, 4), dtype="float32")
+    input_data2 = paddle.randn((4,), dtype="float32")
+    verify_model(Mv(), input_data=[input_data1, input_data2])
+
+
 @tvm.testing.uses_gpu
 def test_forward_pool2d():
     @paddle.jit.to_static
@@ -516,32 +933,42 @@ def pool2d2(inputs):
 
     @paddle.jit.to_static
     def pool2d3(inputs):
-        return nn.functional.max_pool2d(
+        output = nn.functional.max_pool2d(inputs, kernel_size=2, stride=2, padding=0)
+        return output
+
+    @paddle.jit.to_static
+    def pool2d4(inputs):
+        output, max_indices = nn.functional.max_pool2d(
             inputs, kernel_size=2, stride=2, padding=0, return_mask=True
         )
+        return output
 
     input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1)
-    verify_model(pool2d1, input_data=input_data)
+    verify_model(pool2d1, input_data, input_shape=[[-1, 2, 32, 32]])
     verify_model(pool2d2, input_data=input_data)
-    # verify_model(pool2d3, input_data=input_data)
+    input_data1 = paddle.uniform(shape=[1, 2, 1, 50], dtype="float32", min=-1, max=1)
+    verify_model(pool2d3, input_data=input_data1)
 
 
 @tvm.testing.uses_gpu
-def test_forward_relu():
-    @paddle.jit.to_static
-    def relu(inputs):
-        return nn.functional.relu(inputs)
+def test_forward_rank():
+    class Rank(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            rank = paddle.rank(inputs)
+            rank = paddle.unsqueeze(rank, axis=0)
+            output = inputs + rank
+            return output
 
-    input_shape = [10, 10]
+    input_shape = [1, 2, 1, 3, 1]
     input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(relu, input_data=input_data)
+    verify_model(Rank(), input_data=input_data)
 
 
 @tvm.testing.uses_gpu
 def test_forward_reshape():
     @paddle.jit.to_static
-    def reshape1(inputs, x):
-        new_shape = paddle.shape(x)
+    def reshape1(inputs, new_shape):
         return paddle.reshape(inputs, new_shape)
 
@@ -551,7 +978,7 @@ def reshape2(inputs):
     @paddle.jit.to_static
     def reshape3(inputs):
         data_shape = inputs.shape
-        return inputs.reshape([data_shape[0] * data_shape[1], data_shape[2]])
+        return inputs.reshape([data_shape[1], data_shape[2], data_shape[0]])
 
     @paddle.jit.to_static
     def reshape4(inputs, x):
@@ -561,7 +988,8 @@ def reshape4(inputs, x):
     input_shape = [2, 1, 10, 1, 10]
     input_data = paddle.rand(input_shape, dtype="float32")
     input_data2 = paddle.randn([2, 1, 10, 10])
-    verify_model(reshape1, input_data=[input_data, input_data2])
+    new_shape = paddle.shape(input_data2)
+    verify_model(reshape1, [input_data, new_shape], input_shape=[[2, 1, 10, 1, 10], [4]])
     verify_model(reshape2, input_data=input_data)
     verify_model(reshape3, input_data=paddle.randn((2, 3, 4)))
     verify_model(reshape4, input_data=[input_data, input_data2])
@@ -587,11 +1015,55 @@ def scale2(inputs):
     verify_model(scale2, input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_scatter():
+    @paddle.jit.to_static
+    def scatter(x, index, updates):
+        return paddle.scatter(x, index, updates, overwrite=True)
+
+    @paddle.jit.to_static
+    def scatter2(x, index, updates):
+        return paddle.scatter(x, index, updates, overwrite=False)
+
+    x = paddle.rand([10, 8, 5], dtype="float32")
+    index = paddle.to_tensor(
+        [
+            2,
+            1,
+            0,
+            6,
+        ]
+    )
+    updates = paddle.rand([4, 8, 5], dtype="float32")
+    verify_model(scatter, [x, index, updates], input_shape=[[-1, 8, 5], [4], [4, 8, 5]])
+    verify_model(scatter2, [x, index, updates])
+
+
+def test_forward_scatter_nd():
+    @paddle.jit.to_static
+    def scatter_nd(index, updates):
+        shape = [3, 5, 9, 10]
+        return paddle.scatter_nd(index, updates, shape)
+
+    @paddle.jit.to_static
+    def scatter_nd_add(x, index, updates):
+        return paddle.scatter_nd_add(x, index, updates)
+
+    index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
+    index = paddle.to_tensor(index_data)
+    updates = paddle.rand(shape=[3, 9, 10], dtype="float32")
+    verify_model(scatter_nd, [index, updates])
+    x = paddle.rand(shape=[3, 5, 4, 9, 10], dtype="float32")
+    updates = paddle.rand(shape=[3, 2, 9, 10], dtype="float32")
+    index = paddle.randint(0, 3, shape=[3, 2, 3])
+    verify_model(scatter_nd_add, [x, index, updates])
+
+
 @tvm.testing.uses_gpu
 def test_forward_slice():
     @paddle.jit.to_static
-    def slice1(inputs):
-        return inputs[:, :, :, :3]
+    def slice1(inputs, end):
+        return inputs[:, :, :, :end]
 
     @paddle.jit.to_static
     def slice2(inputs):
@@ -607,55 +1079,100 @@ def slice4(inputs):
         x1 = paddle.to_tensor([3]) + paddle.to_tensor([1])
         return inputs[:, x0:, 1:x1, :]
 
+    @paddle.jit.to_static
+    def slice5(inputs):
+        x0 = paddle.to_tensor([3])
+        return inputs[:, 1::1, 2::x0, 4:10]
+
     input_shape = [1, 3, 10, 10]
     input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(
-        slice1,
-        input_data=[
-            input_data,
-        ],
-    )
+    end = paddle.to_tensor(np.array([3]))
+    verify_model(slice1, [input_data, end], input_shape=[[1, 3, 10, 10], [1]])
     verify_model(slice2, input_data=input_data)
-    # need op "strided_slice"
-    # verify_model(slice3, input_data=paddle.randn((4, 4)))
-    # need op "assign_value"
-    # verify_model(slice4, input_data=input_data)
+    verify_model(slice3, input_data=paddle.randn((4, 4)))
+    verify_model(slice4, input_data=input_data)
+    verify_model(slice5, input_data=input_data)
 
 
 @tvm.testing.uses_gpu
-def test_forward_tanh():
+def test_forward_sort():
     @paddle.jit.to_static
-    def tanh(inputs):
-        return paddle.tanh(inputs)
+    def sort(inputs):
+        return paddle.sort(inputs)
 
-    input_shape = [1, 3, 10, 10]
+    @paddle.jit.to_static
+    def sort2(inputs):
+        return paddle.sort(inputs, axis=0, descending=True)
+
+    input_shape = [2, 3, 5]
     input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(tanh, input_data=input_data)
+    verify_model(sort, input_data)
+    input_data2 = np.random.randint(100, size=input_shape)
+    verify_model(sort2, input_data2)
+
+
+@tvm.testing.uses_gpu
+def test_forward_subtract():
+    class Subtract(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, x, y):
+            return paddle.subtract(x, y)
+
+    input_data1 = paddle.to_tensor([2, np.nan, 5], dtype="float32")
+    input_data2 = paddle.to_tensor([1, 4, np.nan], dtype="float32")
+    verify_model(Subtract(), input_data=[input_data1, input_data2])
+
+    input_data1 = paddle.randint(0, 10, (3, 4), dtype="int32")
+    input_data2 = paddle.randint(0, 10, (4,), dtype="int32")
+    verify_model(Subtract(), input_data=[input_data1, input_data2])
+
+    input_data1 = paddle.randint(0, 10, (10, 3, 4), dtype="int64")
+    input_data2 = paddle.randint(0, 10, (3, 4), dtype="int64")
+    verify_model(Subtract(), input_data=[input_data1, input_data2])
 
 
 if __name__ == "__main__":
     test_forward_add_subtract()
+    test_forward_addmm()
     test_forward_argmax()
+    test_forward_argmin()
+    test_forward_argsort()
     test_forward_assign()
     test_forward_batch_norm()
     test_forward_cast()
+    test_forward_clip()
     test_forward_concat_unsqueeze()
-    test_forward_cumsum()
     test_forward_conv()
+    test_forward_crop()
+    test_forward_cumsum()
+    test_forward_dist()
+    test_forward_dot()
     test_forward_dropout()
-    test_forward_shape_full()
-    test_forward_ones_like()
+    test_forward_elemwise()
+    test_forward_expand()
+    test_forward_expand_as()
+    test_forward_ones()
     test_forward_gelu()
-    test_forward_hard_sigmoid()
-    test_forward_hard_swish()
+    test_forward_math()
+    test_forward_activation()
+    test_forward_isinf()
     test_forward_layer_norm()
     test_forward_leaky_relu()
-    test_forward_look_up()
-    test_forward_multiply()
+    test_forward_logical_op()
+    test_forward_lstm()
+    test_forward_gru()
     test_forward_matmul()
+    test_forward_meshgrid()
+    test_forward_mm()
+    test_forward_mv()
+    test_forward_multiply()
     test_forward_pool2d()
-    test_forward_relu()
+    test_forward_rank()
     test_forward_reshape()
     test_forward_scale()
+    test_forward_scatter()
+    test_forward_scatter_nd()
     test_forward_slice()
-    test_forward_tanh()
+    test_forward_sort()
+    test_forward_subtract()
+    test_forward_math()