From 73a21b050c2de14ecb0e89a105dc254aab75eca4 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Fri, 15 Oct 2021 06:28:16 +0000 Subject: [PATCH 1/8] Add autopad for conv/pool --- python/tvm/relay/frontend/paddlepaddle.py | 433 ++++++++++++++---- .../frontend/paddlepaddle/test_forward.py | 415 +++++++++++++---- 2 files changed, 683 insertions(+), 165 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 76a12691d2bf..77f28e540000 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -17,73 +17,217 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """Paddle: PArallel Distributed Deep LEarning.""" -import warnings import numpy as np import tvm from tvm.ir import IRModule +from ... import nd as _nd from .. import analysis +from .. import ty as _ty from .. import expr as _expr from .. import function as _function from .. import ty as _ty from .. import op as _op from .common import ( fold_constant, + get_relay_op, infer_shape, infer_type, infer_value, + try_infer_value, new_var, ) __all__ = ["from_paddle"] +def _autopad( + data, + strides, + kernel_shape, + dilations=[1, 1], + pad_type="constant", + pad_value=0.0, +): + """Perform padding under SAME mode for dynamic and fixed input shapes. + This implementation refers to ONNX frontend. + """ + + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + # get input shape + ndim = len(infer_shape(data)) + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) + + # set up integer constants + zero = _op.const(0, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) + + +def _dtype_shape_promotion(inputs): + """Promote data type and shape for list of tensors.""" + + dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"] + + ranks = [len(infer_shape(x)) for x in inputs] + if set(ranks) == set([1, 0]): + for i, r in enumerate(ranks): + if r == 0: + inputs[i] = _op.expand_dims(inputs[i], axis=0) + + dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs) + if len(dtypes) == 1: + return inputs + max_dtype = dtype_order[max(dtypes)] + for i, input_op in enumerate(inputs): + if infer_type(input_op).checked_type.dtype != max_dtype: + inputs[i] = input_op.astype(max_dtype) + return inputs + + def shape_of(x, dtype="int32"): - """Get shape of a tensor""" + """Get shape of a tensor.""" ttype = infer_type(x).checked_type if not _ty.is_dynamic(ttype): shape = 
list(ttype.shape) - return _expr.const(shape, dtype) + return _expr.const(np.array(shape), dtype) return _op.shape_of(x, dtype) -def _get_pad_size(in_size, dilated_kernel_size, stride_size): - """calculate the paddings size""" +def _convert_dtype_value(val): + """Converts a Paddle type id to a string.""" + + convert_dtype_map = { + 21: "int8", + 20: "uint8", + 6: "float64", + 5: "float32", + 4: "float16", + 3: "int64", + 2: "int32", + 1: "int16", + 0: "bool", + } + if val not in convert_dtype_map: + msg = "Paddle data type value %d is not handled yet." % (val) + raise NotImplementedError(msg) + return convert_dtype_map[val] - if stride_size == 1 or in_size % stride_size == 0: - pad = max(dilated_kernel_size - stride_size, 0) + +def convert_unary_op(g, op, block): + """Operator converter for all the unary operators.""" + + # op_map stores mapping relationship between paddlepaddle and relay + op_map = { + "isinf_v2": _op.isinf, + "isfinite_v2": _op.isfinite, + "isnan_v2": _op.isnan, + } + if op.type in op_map: + unary_func = op_map[op.type] else: - pad = max(dilated_kernel_size - (in_size % stride_size), 0) + # while paddle operator's name is same with relay + unary_func = get_relay_op(op.type) + out = unary_func(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) - pad_before = pad // 2 - pad_after = pad - pad_before - return [pad_before, pad_after] +def convert_binary_logical_op(g, op, block): + """Operator converter for logical op.""" + ipt0 = g.get_node(op.input("X")[0]) + ipt1 = g.get_node(op.input("Y")[0]) + op_func = get_relay_op(op.type) + out = op_func(ipt0, ipt1) + g.add_node(op.output("Out")[0], out) -def convert_arg_max(g, op, block): - """Operator converter for arg_max.""" + +def convert_arg_max_min(g, op, block): + """Operator converter for arg_max and arg_min.""" axis = op.attr("axis") keepdims = op.attr("keepdims") flatten = op.attr("flatten") + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) + func = _op.argmax if op.type == "arg_max" else _op.argmin x = g.get_node(op.input("X")[0]) if axis is None or flatten: x = _op.reshape(x, [-1]) - out = _op.argmax(x, axis=None, keepdims=True) + out = func(x, axis=None, keepdims=True) else: - out = _op.argmax(x, axis=axis, keepdims=keepdims) + out = func(x, axis=axis, keepdims=keepdims) + if dtype != infer_type(out).checked_type.dtype: + out = _op.cast(out, dtype) g.add_node(op.output("Out")[0], out) +def convert_argsort(g, op, block): + """Operator converter for argsort.""" + + x = g.get_node(op.input("X")[0]) + axis = op.attr("axis") + descending = op.attr("descending") + + out_indices = _op.argsort(x, axis, not descending, dtype="int64") + out = _op.gather(x, axis, out_indices) + g.add_node(op.output("Out")[0], out) + g.add_node(op.output("Indices")[0], out_indices) + + def convert_assign(g, op, block): """Operator converter for assign.""" - out = _op.copy(g.get_node(op.input("X")[0])) + out = g.get_node(op.input("X")[0]) + g.add_node(op.output("Out")[0], out) + + +def convert_assign_value(g, op, block): + """Operator converter for assign_value.""" + + keys = ["bool_values", "fp32_values", "int32_values", "int64_values"] + dtypes = ["bool", "float32", "int32", "int64"] + for i, key in enumerate(keys): + dtype = dtypes[i] + value = np.array(op.attr(key)).astype(dtype) + if value is not None and value.size >= 1: + break + shape = op.attr("shape") + value = value.reshape(shape) + out = _op.const(value, dtype=dtype) g.add_node(op.output("Out")[0], out) @@ -110,8 +254,8 @@ def convert_batch_norm(g, op, 
block): def convert_cast(g, op, block): """Operator converter for cast.""" - dtype = block.var(op.output("Out")[0]).dtype - dtype = str(dtype).strip().split(".")[1] + dtype = op.attr("out_dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) out = _op.cast(x, dtype=dtype) g.add_node(op.output("Out")[0], out) @@ -122,6 +266,7 @@ def convert_concat(g, op, block): inputs = [g.get_node(op.input("X")[i]) for i in range(len(op.input("X")))] axis = op.attr("axis") + inputs = _dtype_shape_promotion(inputs) out = _op.concatenate(inputs, axis=axis) g.add_node(op.output("Out")[0], out) @@ -138,17 +283,16 @@ def convert_conv2d(g, op, block): kernel = g.get_node(op.input("Filter")[0]) input_x = g.get_node(op.input("Input")[0]) out_channels, _, k_h, k_w = infer_shape(kernel) - in_h, in_w = infer_shape(input_x)[2:] if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + dilations = [1, 1] + input_x = _autopad(input_x, strides, [k_h, k_w], dilations) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' @@ -191,7 +335,18 @@ def convert_dropout(g, op, block): """Operator converter for dropout.""" x = g.get_node(op.input("X")[0]) - out = _op.copy(x) + g.add_node(op.output("Out")[0], x) + + +def convert_dot(g, op, block): + """Operator converter for dot.""" + + # x, y should be 1D or 2D tensor + # when it's 2D tensor, the first dimension means batch dimension + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + + out = _op.sum(_op.multiply(x, y), axis=[-1], keepdims=True) g.add_node(op.output("Out")[0], out) @@ -199,49 +354,61 @@ def convert_elementwise_op(g, op, block): """Operator converter for all the elementwise operators.""" op_map = { - "elementwise_div": lambda x, y: x / y, - "elementwise_add": lambda x, y: x + y, - "elementwise_mul": lambda x, y: x * y, - "elementwise_sub": lambda x, y: x - y, - "elementwise_mod": lambda x, y: x % y, + "elementwise_div": "divide", + "elementwise_add": "add", + "elementwise_mul": "multiply", + "elementwise_sub": "subtract", + "elementwise_mod": "mod", + "elementwise_max": "maximum", + "elementwise_min": "minimum", + "elementwise_pow": "power", + "elementwise_floordiv": "floor_divide", + "equal": "equal", + "greater_equal": "greater_equal", + "greater_than": "greater", + "less_equal": "less_equal", + "less_than": "less", + "not_equal": "not_equal", } op_func = op_map[op.type] ipt0 = g.get_node(op.input("X")[0]) ipt1 = g.get_node(op.input("Y")[0]) - ipt0_shape = block.var(op.input("X")[0]).shape - ipt1_shape = block.var(op.input("Y")[0]).shape + ipt0_shape = infer_shape(ipt0) + ipt1_shape = infer_shape(ipt1) axis = op.attr("axis") if len(ipt0_shape) != len(ipt1_shape): if axis < 0: axis = axis + len(ipt0_shape) if axis != len(ipt0_shape) - 1: ipt1 = _op.expand_dims(ipt1, axis=axis, num_newaxis=(len(ipt0_shape) - axis - 1)) + op_func = get_relay_op(op_func) out = op_func(ipt0, ipt1) g.add_node(op.output("Out")[0], out) -def convert_equal(g, op, block): - """Operator converter for equal.""" +def convert_expand(g, op, 
block): + """Operator converter for expand.""" x = g.get_node(op.input("X")[0]) - y = g.get_node(op.input("Y")[0]) - out = _op.equal(x, y) + if op.input("Shape"): + sizes = g.get_node(op.input("Shape")[0]) + sizes = try_infer_value(sizes, g.get_params())[0] + else: + sizes = op.attr("shape") + + if isinstance(sizes, np.ndarray): + sizes = sizes.tolist() + + out = _op.broadcast_to(x, sizes) g.add_node(op.output("Out")[0], out) -def convert_activation(g, op, block): - """Operator converter for all the activation.""" +def convert_expand_as(g, op, block): + """Operator converter for expand_as.""" - op_map = { - "exp": _op.exp, - "relu": _op.nn.relu, - "tanh": _op.tanh, - "sqrt": _op.sqrt, - "erf": _op.erf, - "abs": _op.abs, - } - act_func = op_map[op.type] - out = act_func(g.get_node(op.input("X")[0])) + x = g.get_node(op.input("X")[0]) + target_shape = op.attr("target_shape") + out = _op.broadcast_to(x, target_shape) g.add_node(op.output("Out")[0], out) @@ -259,6 +426,12 @@ def convert_feed(g, op, block): ipt_name = op.name if g.shape_dict is not None: ipt_shape = g.shape_dict[ipt_name] + + if isinstance(ipt_shape, tuple): + ipt_shape = list(ipt_shape) + for i, s in enumerate(ipt_shape): + if s < 0: + ipt_shape[i] = _ty.Any() out = new_var(ipt_name, shape=ipt_shape, dtype=ipt_dtype) g.add_node(ipt_name, out) @@ -266,18 +439,11 @@ def convert_feed(g, op, block): def convert_fill_any_like(g, op, block): """Operator converter for fill_any_like.""" - out_name = op.output("Out")[0] - out_dtype = block.var(out_name).dtype - out_dtype = str(out_dtype).strip().split(".")[1] + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) x = g.get_node(op.input("X")[0]) - ipt_type = infer_type(x).checked_type - value = op.attr("value") - if not _ty.is_dynamic(ipt_type): - shape = infer_shape(x) - const = np.ones(shape) * value - out = _expr.const(const.astype(out_dtype)) - else: - out = _op.transform.full_like(x, value).astype(out_dtype) + value = _expr.const(op.attr("value"), dtype=dtype) + out = _op.transform.full_like(x, value).astype(dtype) g.add_node(op.output("Out")[0], out) @@ -286,16 +452,20 @@ def convert_fill_constant(g, op, block): value = op.attr("value") shape = block.var(op.output("Out")[0]).shape - dtype = block.var(op.output("Out")[0]).dtype - dtype = str(dtype).strip().split(".")[1] - if op.input("ValueTensor"): + dtype = op.attr("dtype") + dtype = _convert_dtype_value(dtype) + value = _expr.const(value).astype(dtype) + if "ValueTensor" in op.input_names and op.input("ValueTensor"): shape = g.get_node(op.input("ValueTensor")[0]) - shape = infer_value(shape, g.get_params()).numpy() - if op.input("ShapeTensor"): + shape = try_infer_value(shape, g.get_params())[0] + if "ShapeTensor" in op.input_names and op.input("ShapeTensor"): shape = g.get_node(op.input("ShapeTensor")[0]) - shape = infer_value(shape, g.get_params()).numpy() - value = np.full(shape, value, dtype) - out = _expr.const(value.astype(dtype)).astype(dtype) + shape = try_infer_value(shape, g.get_params())[0] + + if isinstance(shape, np.ndarray): + shape = shape.tolist() + + out = _op.full(value, shape=shape, dtype=dtype) g.add_node(op.output("Out")[0], out) @@ -543,6 +713,39 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_padding(g, op, block): + """Operator converter for padding.""" + + input_x = g.get_node(op.input("X")[0]) + input_padding = op.input("Paddings") + if input_padding: + padding = g.get_node(input_padding[0]) + padding = infer_value(padding, g.get_params()).numpy().tolist() 
+ else: + padding = op.attr("paddings") + padding = op.attr("paddings") + value = op.attr("value") + data_format = op.attr("data_format") + mode = op.attr("mode") + assert mode != "circular", "Don't support mod='circular' for PaddlePaddle's padding" + if mode == "replicate": + mode = "edge" + + pad_len = len(padding) + new_paddings = [0] * (pad_len + 4) + for i in range(0, pad_len, 2): + index = -1 - i + if data_format[:2] != "NC": + index = -3 - i + new_paddings[index] = padding[i + 1] + new_paddings[index - 1] = padding[i] + + new_paddings = [new_paddings[i : i + 2] for i in range(0, len(new_paddings), 2)] + + out = _op.nn.pad(input_x, new_paddings, pad_value=value, pad_mode=mode) + g.add_node(op.output("Out")[0], out) + + def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -553,17 +756,19 @@ def convert_pool2d(g, op, block): paddings = op.attr("paddings") padding_algorithm = op.attr("padding_algorithm") pooling_type = op.attr("pooling_type") + if global_pooling: adaptive = True ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - in_h, in_w = infer_shape(input_x)[2:] + _, _, in_h, in_w = infer_shape(input_x) op_map = { "avg": "avg_pool2d", "max": "max_pool2d", } + strides = op.attr("strides") if isinstance(strides, int): strides = [strides, strides] @@ -575,22 +780,37 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + input_x = _autopad(input_x, strides, ksize) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + ksize[0] = in_h + if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + ksize[1] = in_w + if not adaptive: - out = getattr(_op.nn, op_map[pooling_type])( - input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode - ) + if pooling_type == "avg": + exclusive = op.attr("exclusive") + out = _op.nn.avg_pool2d( + input_x, + pool_size=ksize, + strides=strides, + padding=paddings, + ceil_mode=ceil_mode, + count_include_pad=not exclusive, + ) + else: + out = getattr(_op.nn, op_map[pooling_type])( + input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode + ) else: out = getattr(_op.nn, "adaptive_" + op_map[pooling_type])(input_x, output_size=ksize) g.add_node(op.output("Out")[0], out) @@ -711,6 +931,17 @@ def convert_softmax(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_squeeze(g, op, block): + """Operator converter for squeeze2.""" + + x = g.get_node(op.input("X")[0]) + axes = op.attr("axes") + if not axes: + axes = None + x = _op.squeeze(x, axis=axes) + g.add_node(op.output("Out")[0], x) + + def convert_unsqueeze(g, op, block): """Operator converter for unsqueeze.""" @@ -722,41 +953,55 @@ def convert_unsqueeze(g, op, block): _convert_map = { - "arg_max": convert_arg_max, + "arg_max": convert_arg_max_min, + "arg_min": convert_arg_max_min, + "argsort": convert_argsort, "assign": convert_assign, + "assign_value": convert_assign_value, "batch_norm": 
convert_batch_norm, "cast": convert_cast, "concat": convert_concat, "conv2d": convert_conv2d, "cumsum": convert_cumsum, "depthwise_conv2d": convert_conv2d, + "dot": convert_dot, "dropout": convert_dropout, "elementwise_add": convert_elementwise_op, "elementwise_div": convert_elementwise_op, "elementwise_mul": convert_elementwise_op, "elementwise_sub": convert_elementwise_op, - "equal": convert_equal, - "exp": convert_activation, + "equal": convert_elementwise_op, + "exp": convert_unary_op, + "expand_v2": convert_expand, + "expand_as_v2": convert_expand_as, "feed": convert_feed, "fill_any_like": convert_fill_any_like, "fill_constant": convert_fill_constant, "gelu": convert_gelu, "hard_sigmoid": convert_hard_sigmoid, "hard_swish": convert_hard_swish, + "isfinite_v2": convert_unary_op, + "isinf_v2": convert_unary_op, + "isnan_v2": convert_unary_op, "layer_norm": convert_layer_norm, "leaky_relu": convert_leaky_relu, + "logical_and": convert_binary_logical_op, + "logical_or": convert_binary_logical_op, + "logical_xor": convert_binary_logical_op, "lookup_table_v2": convert_lookup_table, "matmul": convert_matmul, "matmul_v2": convert_matmul, "mul": convert_mul, + "pad3d": convert_padding, "pool2d": convert_pool2d, - "relu": convert_activation, + "relu": convert_unary_op, "reshape2": convert_reshape, "scale": convert_scale, "shape": convert_shape, "slice": convert_slice, "softmax": convert_softmax, - "tanh": convert_activation, + "squeeze2": convert_squeeze, + "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } @@ -781,7 +1026,7 @@ def add_node(self, name, node): self.nodes[name] = fold_constant(node) def get_params(self, name=None): - """get params from graph""" + """Get params from graph.""" if name is None: return self.params @@ -800,10 +1045,12 @@ def extract_parameters(self, program, scope=None): if not var.persistable: continue if isinstance(scope, dict): - self.params[name] = scope[name] + self.params[name] = _nd.array(scope[name]) else: - self.params[name] = np.array(scope.var(name).get_tensor()) - self.nodes[name] = _expr.const(self.params[name]) + self.params[name] = _nd.array(np.array(scope.var(name).get_tensor())) + shape = self.params[name].shape + dtype = self.params[name].dtype + self.nodes[name] = new_var(name, shape=shape, dtype=dtype) def check_input_shape(self, op, block): """Check the shape information of model's inputs, fixed shape is recommended.""" @@ -894,14 +1141,32 @@ def from_translated_layer(self, layer, shape_dict): free_vars = analysis.free_vars(outputs) func = _function.Function(free_vars, outputs) mod = IRModule.from_expr(func) + # remove unused parameters + final_params = dict() + for var in free_vars: + if var.name_hint in self.params: + final_params[var.name_hint] = self.params[var.name_hint] + self.params = final_params return mod, self.params def from_paddle(program_or_layer, shape_dict=None, scope=None): """Convert a PaddlePaddle model into an equivalent Relay Function. - PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, and PaddlePaddle scope stores all the weights of PaddlePaddle model. 
+ Parameters + ---------- + program_or_layer : object of `paddle.static.Program` or `paddle.jit.TranslatedLayer` + Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load` + shape_dict : dict of str to tuple/list, optional + The input shape of model + scope : object of `paddle.static.Scope`, optional + The scope that saves all the weights of model, use `paddle.static.global_scope` by default + Returns + ------- + mod : tvm.IRModule + The relay module for compilation + params : dict of str to tvm.nd.NDArray """ import paddle diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index db07e07f9d83..b274d178c9c2 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -24,6 +24,7 @@ import tvm.topi.testing from tvm import relay from tvm.contrib import graph_executor +import pytest import paddle import paddle.nn as nn @@ -79,11 +80,11 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5): baseline_outputs = (baseline_outputs.numpy(),) mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) - parms_num = min(len(input_names), len(mod["main"].params)) compiled_names = [] - for arg in mod["main"].params[:parms_num]: - assert arg.name_hint in input_names - compiled_names.append(arg.name_hint) + for arg in mod["main"].params: + assert arg.name_hint in input_names or arg.name_hint in params + if arg.name_hint in input_names: + compiled_names.append(arg.name_hint) with tvm.transform.PassContext(opt_level=3): for target, dev in tvm.testing.enabled_targets(): @@ -125,9 +126,7 @@ def add_subtract3(inputs1, inputs2): @tvm.testing.uses_gpu -def test_forward_argmax(): - input_shape = [1, 3, 10, 10] - +def test_forward_arg_max_min(): class ArgMax(nn.Layer): @paddle.jit.to_static def forward(self, inputs): @@ -148,11 +147,70 @@ class ArgMax3(nn.Layer): def forward(self, inputs): return inputs.argmax(axis=2, keepdim=True) - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(ArgMax(), input_data=input_data) - verify_model(ArgMax1(), input_data=input_data) - verify_model(ArgMax2(), input_data=input_data) - verify_model(ArgMax3(), input_data=input_data) + class ArgMin(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argmin(inputs) + + class ArgMin1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=1) + + class ArgMin2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1, keepdim=False) + + class ArgMin3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmin(axis=2, keepdim=True) + + input_shapes = [[256], [5, 28], [10, 5, 4], [1, 3, 8, 8]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax(), input_data=input_data) + verify_model(ArgMin(), input_data=input_data) + for input_shape in input_shapes[1:]: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax1(), input_data=input_data) + verify_model(ArgMax2(), input_data=input_data) + verify_model(ArgMin1(), input_data=input_data) + verify_model(ArgMin2(), input_data=input_data) + for input_shape in input_shapes[2:]: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax3(), input_data=input_data) + verify_model(ArgMin3(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_argsort(): + class 
ArgSort1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argsort(inputs) + + class ArgSort2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argsort(inputs, axis=0, descending=True) + + class ArgSort3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argsort(inputs, axis=-1, descending=True) + + input_shapes = [[256], [10, 20], [10, 5, 3], [1, 3, 5, 5]] + for input_shape in input_shapes: + # Avoid duplicate elements in the array which will bring + # different results with different sort algorithms + np.random.seed(13) + np_data = np.random.choice(range(-5000, 5000), np.prod(input_shape), replace=False) + input_data = paddle.to_tensor(np_data.reshape(input_shape).astype("int64")) + verify_model(ArgSort1(), [input_data]) + verify_model(ArgSort2(), [input_data]) + verify_model(ArgSort3(), [input_data]) @tvm.testing.uses_gpu @@ -161,6 +219,11 @@ def test_forward_assign(): def assign(inputs): return paddle.assign(inputs) + @paddle.jit.to_static + def assign_value(inputs): + x = paddle.to_tensor(np.array([3]).astype("float32")) + return inputs + x + input_shape = [2, 3] input_data = paddle.rand(input_shape, dtype="float32") verify_model( @@ -176,6 +239,7 @@ def assign(inputs): input_data2, ], ) + verify_model(assign_value, [input_data]) @tvm.testing.uses_gpu @@ -241,6 +305,31 @@ def cast2(inputs, dtype="int64"): ) +@tvm.testing.uses_gpu +def test_forward_check_tensor(): + class IsFinite(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cast(paddle.isfinite(inputs), "int32") + + class IsNan(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cast(paddle.isnan(inputs), "int32") + + class IsInf(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.cast(paddle.isinf(inputs), "int32") + + input_shapes = [[32], [8, 32], [2, 5, 20], [2, 3, 8, 8], [2, 2, 3, 6, 6]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(IsFinite(), input_data=input_data) + verify_model(IsNan(), input_data=input_data) + verify_model(IsInf(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_concat_unsqueeze(): @paddle.jit.to_static @@ -293,31 +382,51 @@ def cusum3(inputs): @tvm.testing.uses_gpu def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] - class Conv2D1(nn.Layer): - def __init__(self): + def __init__(self, stride=1, padding=0, dilation=1, groups=1, padding_mode="zeros"): super(Conv2D1, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.conv = nn.Conv2D( + 3, + 6, + 3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) self.softmax = nn.Softmax() @paddle.jit.to_static def forward(self, inputs): return self.softmax(self.conv(inputs)) - class Conv2D2(nn.Layer): - def __init__(self): - super(Conv2D2, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) - self.softmax = nn.Softmax() + input_shapes = [[1, 3, 10, 10], [1, 3, 12, 12]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="VALID", dilation=3), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=3), input_data=input_data) + verify_model( + Conv2D1(stride=2, padding=3, dilation=3, padding_mode="replicate"), + input_data=input_data, + ) + 
verify_model(Conv2D1(stride=2, padding="SAME", dilation=2, groups=3), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_dot(): + class Dot(nn.Layer): @paddle.jit.to_static - def forward(self, inputs): - return self.softmax(self.conv(inputs)) + def forward(self, x, y): + return paddle.dot(x, y) - conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") - verify_model(Conv2D1(), input_data=conv2d_input_data) - verify_model(Conv2D2(), input_data=conv2d_input_data) + input_shapes = [[128], [8, 24]] + for input_shape in input_shapes: + x_data = paddle.rand(input_shape, dtype="float32") + y_data = paddle.rand(input_shape, dtype="float32") + verify_model(Dot(), input_data=[x_data, y_data]) @tvm.testing.uses_gpu @@ -332,6 +441,93 @@ def dropout(inputs): verify_model(dropout, input_data=input_data) +def test_forward_elemwise(): + class ElemwiseAPI(nn.Layer): + def __init__(self, api_name): + super(ElemwiseAPI, self).__init__() + self.api_name_ = api_name + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, api_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, input1, input2): + y = self.func(input1, input2) + if "equal" in self.api_name_ or "than" in self.api_name_: + # for compare operation, cast boolean result to int32 + y = paddle.cast(y, "int32") + return y + + api_list = [ + "equal", + ] + x_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]] + y_shapes = [[1], [8, 20], [4, 1, 1], [2, 3, 8, 8], [2, 3, 3, 9, 1]] + for x_shape, y_shape in zip(x_shapes, y_shapes): + x_data = paddle.randint(1, 1000, x_shape, dtype="int32") + y_data = paddle.randint(1, 1000, y_shape, dtype="int32") + for api_name in api_list: + verify_model(ElemwiseAPI(api_name), [x_data, y_data]) + + +@tvm.testing.uses_gpu +def test_forward_expand(): + @paddle.jit.to_static + def expand1(inputs): + return paddle.expand(inputs, shape=[2, 128]) + + @paddle.jit.to_static + def expand2(inputs): + return paddle.expand(inputs, shape=[2, 1, 4, 16]) + + @paddle.jit.to_static + def expand3(inputs): + return paddle.expand(inputs, shape=[2, 1, 3, 7, 7]) + + @paddle.jit.to_static + def expand4(inputs): + shape = paddle.to_tensor(np.array([2, 128]).astype("int32")) + return paddle.expand(inputs, shape=shape) + + @paddle.jit.to_static + def expand5(inputs): + shape = paddle.to_tensor(np.array([2, 1, 4, 16]).astype("int32")) + return paddle.expand(inputs, shape=shape) + + @paddle.jit.to_static + def expand6(inputs): + shape = paddle.to_tensor(np.array([2, 1, 3, 7, 7]).astype("int32")) + return paddle.expand(inputs, shape=shape) + + data = paddle.rand([128], dtype="float32") + verify_model(expand1, input_data=[data]) + verify_model(expand4, input_data=[data]) + data = paddle.rand([4, 16], dtype="float32") + verify_model(expand2, input_data=[data]) + verify_model(expand5, input_data=[data]) + data = paddle.rand([1, 3, 7, 7], dtype="float32") + verify_model(expand3, input_data=[data]) + verify_model(expand6, input_data=[data]) + + +@tvm.testing.uses_gpu +def test_forward_expand_as(): + class ExpandAs(nn.Layer): + @paddle.jit.to_static + def forward(self, x, y): + z = paddle.expand_as(x, y) + z += y + return z + + x_shapes = [[1], [8, 128], [8, 1, 1], [2, 3, 229, 229], [2, 3, 3, 224, 1]] + y_shapes = [[128], [8, 128], [8, 200, 300], [2, 3, 229, 229], [2, 3, 3, 224, 224]] + for x_shape, y_shape in zip(x_shapes, y_shapes): + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.rand(y_shape, dtype="float32") + 
verify_model(ExpandAs(), [x_data, y_data]) + + @tvm.testing.uses_gpu def test_forward_shape_full(): @paddle.jit.to_static @@ -348,6 +544,26 @@ def full2(inputs): verify_model(full2, input_data=[input_data]) +@tvm.testing.uses_gpu +def test_forward_squeeze(): + class Squeeze(nn.Layer): + def __init__(self, axis=None): + super(Squeeze, self).__init__() + self.axis = axis + + @paddle.jit.to_static + def forward(self, inputs): + return paddle.squeeze(inputs, axis=self.axis) + + input_shapes = [[1, 1, 3, 1, 5], [5, 1, 6]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Squeeze(axis=None), input_data=input_data) + verify_model(Squeeze(axis=1), input_data=input_data) + input_data = paddle.rand([1], dtype="float32") + verify_model(Squeeze(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_ones_like(): @paddle.jit.to_static @@ -432,6 +648,32 @@ def leaky_relu(inputs): verify_model(leaky_relu, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_logical_api(): + class LogicalAPI(nn.Layer): + def __init__(self, api_name): + super(LogicalAPI, self).__init__() + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, api_name, None) + if self.func: + break + + @paddle.jit.to_static + def forward(self, x, y): + out = paddle.to_tensor([True, True, True]) + z = self.func(x, y, out=out) + return paddle.cast(z, "int32") + + x_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]] + y_shapes = [[1], [8, 20], [4, 1, 1], [2, 3, 8, 8], [2, 3, 3, 9, 1]] + for x_shape, y_shape in zip(x_shapes, y_shapes): + x_data = paddle.randint(0, 2, x_shape).astype("bool") + y_data = paddle.randint(0, 2, y_shape).astype("bool") + verify_model(LogicalAPI("logical_and"), [x_data, y_data]) + verify_model(LogicalAPI("logical_or"), [x_data, y_data]) + verify_model(LogicalAPI("logical_xor"), [x_data, y_data]) + + @tvm.testing.uses_gpu def test_forward_look_up(): @paddle.jit.to_static @@ -506,35 +748,55 @@ def forward(self, input1, input2): @tvm.testing.uses_gpu def test_forward_pool2d(): - @paddle.jit.to_static - def pool2d1(inputs): - return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) - - @paddle.jit.to_static - def pool2d2(inputs): - return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + class Pool2D1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) - @paddle.jit.to_static - def pool2d3(inputs): - return nn.functional.max_pool2d( - inputs, kernel_size=2, stride=2, padding=0, return_mask=True - ) + class Pool2D2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) - input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) - verify_model(pool2d1, input_data=input_data) - verify_model(pool2d2, input_data=input_data) - # verify_model(pool2d3, input_data=input_data) + class Pool2D3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d( + inputs, + kernel_size=3, + stride=1, + padding=[1, 1], + exclusive=False, + divisor_override=2.5, + ) + + input_shapes = [[1, 2, 8, 8], [1, 3, 10, 10]] + for input_shape in input_shapes: + input_data = paddle.uniform(shape=input_shape, dtype="float32", min=-1, max=1) + verify_model(Pool2D1(), input_data=input_data) + verify_model(Pool2D2(), input_data=input_data) + verify_model(Pool2D3(), 
input_data=input_data) @tvm.testing.uses_gpu -def test_forward_relu(): - @paddle.jit.to_static - def relu(inputs): - return nn.functional.relu(inputs) +def test_forward_pad3d(): + class Pad3D(nn.Layer): + def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCDHW"): + super(Pad3D, self).__init__() + self.pad3d = paddle.nn.Pad3D(padding, mode=mode, value=value, data_format=data_format) - input_shape = [10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(relu, input_data=input_data) + @paddle.jit.to_static + def forward(self, inputs): + return self.pad3d(inputs) + + input_shapes = [[1, 2, 2, 5, 5], [1, 2, 2, 5, 9]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Pad3D(padding=2), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1]), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], value=0.3), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], mode="reflect"), input_data=input_data) + verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data) @tvm.testing.uses_gpu @@ -623,39 +885,30 @@ def slice4(inputs): @tvm.testing.uses_gpu -def test_forward_tanh(): - @paddle.jit.to_static - def tanh(inputs): - return paddle.tanh(inputs) +def test_forward_math_api(): + class MathAPI(nn.Layer): + def __init__(self, api_name): + super(MathAPI, self).__init__() + for candidate in (paddle, paddle.nn.functional): + self.func = getattr(candidate, api_name, None) + if self.func: + break - input_shape = [1, 3, 10, 10] - input_data = paddle.rand(input_shape, dtype="float32") - verify_model(tanh, input_data=input_data) + @paddle.jit.to_static + def forward(self, inputs): + return self.func(inputs) + + api_list = [ + "exp", + "relu", + "tanh", + ] + input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + for api_name in api_list: + verify_model(MathAPI(api_name), input_data=input_data) if __name__ == "__main__": - test_forward_add_subtract() - test_forward_argmax() - test_forward_assign() - test_forward_batch_norm() - test_forward_cast() - test_forward_concat_unsqueeze() - test_forward_cumsum() - test_forward_conv() - test_forward_dropout() - test_forward_shape_full() - test_forward_ones_like() - test_forward_gelu() - test_forward_hard_sigmoid() - test_forward_hard_swish() - test_forward_layer_norm() - test_forward_leaky_relu() - test_forward_look_up() - test_forward_multiply() - test_forward_matmul() - test_forward_pool2d() - test_forward_relu() - test_forward_reshape() - test_forward_scale() - test_forward_slice() - test_forward_tanh() + pytest.main([__file__]) From 30a9922eb5c70b82a5b4c4a669a3c1cdafecb5cf Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Fri, 15 Oct 2021 06:35:58 +0000 Subject: [PATCH 2/8] add autopad for conv/pool --- python/tvm/relay/frontend/paddlepaddle.py | 85 ------------------- .../frontend/paddlepaddle/test_forward.py | 17 ---- 2 files changed, 102 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index c4aae8000cb6..77f28e540000 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -43,7 +43,6 @@ __all__ = ["from_paddle"] -<<<<<<< HEAD def _autopad( data, strides, @@ -94,20 +93,6 @@ def _autopad( if isinstance(pad_value, (float, int)): pad_value = _op.const(pad_value) return _op.nn.pad(data, 
fold_constant(pad), pad_value, pad_type) -======= -def _get_pad_size(in_size, dilated_kernel_size, stride_size): - """Calculate the paddings size for Conv/Pool in SAME padding mode.""" - - if stride_size == 1 or in_size % stride_size == 0: - pad = max(dilated_kernel_size - stride_size, 0) - else: - pad = max(dilated_kernel_size - (in_size % stride_size), 0) - - pad_before = pad // 2 - pad_after = pad - pad_before - - return [pad_before, pad_after] ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 def _dtype_shape_promotion(inputs): @@ -143,7 +128,6 @@ def shape_of(x, dtype="int32"): def _convert_dtype_value(val): """Converts a Paddle type id to a string.""" -<<<<<<< HEAD convert_dtype_map = { 21: "int8", @@ -161,25 +145,6 @@ def _convert_dtype_value(val): raise NotImplementedError(msg) return convert_dtype_map[val] -======= - - convert_dtype_map = { - 21: "int8", - 20: "uint8", - 6: "float64", - 5: "float32", - 4: "float16", - 3: "int64", - 2: "int32", - 1: "int16", - 0: "bool", - } - if val not in convert_dtype_map: - msg = "Paddle data type value %d is not handled yet." % (val) - raise NotImplementedError(msg) - return convert_dtype_map[val] - ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 def convert_unary_op(g, op, block): """Operator converter for all the unary operators.""" @@ -201,26 +166,14 @@ def convert_unary_op(g, op, block): def convert_binary_logical_op(g, op, block): """Operator converter for logical op.""" -<<<<<<< HEAD -======= ipt0 = g.get_node(op.input("X")[0]) ipt1 = g.get_node(op.input("Y")[0]) op_func = get_relay_op(op.type) out = op_func(ipt0, ipt1) g.add_node(op.output("Out")[0], out) ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 - ipt0 = g.get_node(op.input("X")[0]) - ipt1 = g.get_node(op.input("Y")[0]) - op_func = get_relay_op(op.type) - out = op_func(ipt0, ipt1) - g.add_node(op.output("Out")[0], out) -<<<<<<< HEAD - -======= ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 def convert_arg_max_min(g, op, block): """Operator converter for arg_max and arg_min.""" @@ -333,26 +286,9 @@ def convert_conv2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": -<<<<<<< HEAD dilations = [1, 1] input_x = _autopad(input_x, strides, [k_h, k_w], dilations) paddings = [0, 0] -======= - if strides[0] == 1 and strides[1] == 1: - pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) - else: - input_shape = shape_of(input_x) - h_w = _op.strided_slice(input_shape, [2], [4]) - try: - in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() - except Exception as e: - msg = "Dynamic shape is not supported in SAME padding algorithm while stride!=1" - raise tvm.error.OpAttributeInvalid(msg) from e - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] @@ -1064,10 +1000,7 @@ def convert_unsqueeze(g, op, block): "shape": convert_shape, "slice": convert_slice, "softmax": convert_softmax, -<<<<<<< HEAD "squeeze2": convert_squeeze, -======= ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 "tanh": convert_unary_op, "unsqueeze2": convert_unsqueeze, } @@ -1221,36 +1154,18 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None): 
"""Convert a PaddlePaddle model into an equivalent Relay Function. PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, and PaddlePaddle scope stores all the weights of PaddlePaddle model. -<<<<<<< HEAD -======= - ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 Parameters ---------- program_or_layer : object of `paddle.static.Program` or `paddle.jit.TranslatedLayer` Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load` -<<<<<<< HEAD shape_dict : dict of str to tuple/list, optional The input shape of model scope : object of `paddle.static.Scope`, optional The scope that saves all the weights of model, use `paddle.static.global_scope` by default -======= - - shape_dict : dict of str to tuple/list, optional - The input shape of model - - scope : object of `paddle.static.Scope`, optional - The scope that saves all the weights of model, use `paddle.static.global_scope` by default - ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 Returns ------- mod : tvm.IRModule The relay module for compilation -<<<<<<< HEAD -======= - ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 params : dict of str to tvm.nd.NDArray """ diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 77c71d487628..b274d178c9c2 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -429,20 +429,6 @@ def forward(self, x, y): verify_model(Dot(), input_data=[x_data, y_data]) -@tvm.testing.uses_gpu -def test_forward_dot(): - class Dot(nn.Layer): - @paddle.jit.to_static - def forward(self, x, y): - return paddle.dot(x, y) - - input_shapes = [[128], [8, 24]] - for input_shape in input_shapes: - x_data = paddle.rand(input_shape, dtype="float32") - y_data = paddle.rand(input_shape, dtype="float32") - verify_model(Dot(), input_data=[x_data, y_data]) - - @tvm.testing.uses_gpu def test_forward_dropout(): @paddle.jit.to_static @@ -793,7 +779,6 @@ def forward(self, inputs): @tvm.testing.uses_gpu -<<<<<<< HEAD def test_forward_pad3d(): class Pad3D(nn.Layer): def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCDHW"): @@ -815,8 +800,6 @@ def forward(self, inputs): @tvm.testing.uses_gpu -======= ->>>>>>> a0a33fb0da7ed36fe743c7e98df8db30d2e87125 def test_forward_reshape(): @paddle.jit.to_static def reshape1(inputs, x): From 7c2afce0241706150096b85f5498f1ce68d0cd9e Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Sat, 16 Oct 2021 08:25:36 +0000 Subject: [PATCH 3/8] fix pylint warning --- python/tvm/relay/frontend/paddlepaddle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 77f28e540000..492ff76bab5d 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -18,6 +18,7 @@ # pylint: disable=import-outside-toplevel """Paddle: PArallel Distributed Deep LEarning.""" +import warnings import numpy as np import tvm @@ -47,7 +48,7 @@ def _autopad( data, strides, kernel_shape, - dilations=[1, 1], + dilations=(1, 1), pad_type="constant", pad_value=0.0, ): From cd8296add0d6c3a8a1409faaab9633cdadf3ae1c Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Sat, 16 Oct 2021 08:28:12 +0000 Subject: [PATCH 4/8] add some annotations --- python/tvm/relay/frontend/paddlepaddle.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relay/frontend/paddlepaddle.py 
b/python/tvm/relay/frontend/paddlepaddle.py
index 492ff76bab5d..7a72d05a7982 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -792,6 +792,9 @@ def convert_pool2d(g, op, block):
         msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."'
         raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))
 
+    # Handle a special case:
+    # if the kernel size is larger than the input size,
+    # shrink the kernel size to the input size
     if not isinstance(in_h, _op.Expr) and in_h < ksize[0]:
         ksize[0] = in_h
     if not isinstance(in_w, _op.Expr) and in_w < ksize[1]:

From 4ebcdebed283b37dcfeaf9f91cf64301f85e4a68 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Sat, 16 Oct 2021 08:29:22 +0000
Subject: [PATCH 5/8] add some annotations

---
 python/tvm/relay/frontend/paddlepaddle.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 7a72d05a7982..91f1ee1d3da0 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -1158,6 +1158,7 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None):
     """Convert a PaddlePaddle model into an equivalent Relay Function.
     PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model,
     and PaddlePaddle scope stores all the weights of PaddlePaddle model.
+
     Parameters
     ----------
     program_or_layer : object of `paddle.static.Program` or `paddle.jit.TranslatedLayer`
         Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load`
     shape_dict : dict of str to tuple/list, optional
         The input shape of model
     scope : object of `paddle.static.Scope`, optional
         The scope that saves all the weights of model, use `paddle.static.global_scope` by default
+
     Returns
     -------
     mod : tvm.IRModule
         The relay module for compilation
     params : dict of str to tvm.nd.NDArray
     """

From ab3278a036a93a87622d3f0bf8d32f51a2baa1c0 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Sat, 16 Oct 2021 08:29:56 +0000
Subject: [PATCH 6/8] add some annotations

---
 python/tvm/relay/frontend/paddlepaddle.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 91f1ee1d3da0..70a46c966b84 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -1163,8 +1163,10 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None):
     ----------
     program_or_layer : object of `paddle.static.Program` or `paddle.jit.TranslatedLayer`
         Loaded model by `paddle.static.load_inference_model` or `paddle.jit.load`
+
     shape_dict : dict of str to tuple/list, optional
         The input shape of model
+
     scope : object of `paddle.static.Scope`, optional
         The scope that saves all the weights of model, use `paddle.static.global_scope` by default
@@ -1172,6 +1174,7 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None):
     -------
     mod : tvm.IRModule
         The relay module for compilation
+
     params : dict of str to tvm.nd.NDArray
     """

From fee1886a6f3924b2df3eab9c978e46ed2aa71241 Mon Sep 17 00:00:00 2001
From: heliqi <1101791222@qq.com>
Date: Fri, 22 Oct 2021 17:24:52 +0800
Subject: [PATCH 7/8] Refactor autopad in the onnx.py and paddlepaddle.py to relay/frontend/common.py

---
 python/tvm/relay/frontend/common.py       | 74 +++++++++++++++++++++
 python/tvm/relay/frontend/onnx.py         | 79 +----------------------
 python/tvm/relay/frontend/paddlepaddle.py | 78 +++-------------------
 3 files changed, 85 insertions(+), 146 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 
825a586918f8..cf579923e301 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -28,6 +28,7 @@ from .. import function as _function from .. import transform as _transform from .. import op as _op +from .. import ty as _ty from .. import analysis # pylint: disable=invalid-name @@ -594,6 +595,16 @@ def try_infer_value(val, on_success=None, on_failure=None): return val, False +def shape_of(x, dtype="int64"): + """Get shape of a tensor.""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): return _expr.var(name_hint, type_annotation, shape, dtype) @@ -837,6 +848,69 @@ def lstm_cell( return outputs_list, hidden_state, cell_state +def autopad( + data, + strides, + kernel_shape, + dilations=(1, 1), + pad_type="constant", + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, +): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + # get input shape + ndim = len(infer_shape(data)) + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + if "LOWER" in mode: + pad = _op.concatenate( + [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 + ) + else: + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) + + def ensure_scalar_shape(x): """ Assume that `x` is a tensor with one element (regardless of tensor rank). diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5c112c7dfce0..3c88f659f6f0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -38,6 +38,7 @@ from .. import ty as _ty from .. 
import vision as _vision from .common import ( + autopad, AttrCvt, Renamer, ensure_scalar_shape, @@ -51,6 +52,7 @@ infer_value, lstm_cell, new_var, + shape_of, try_resolve_var_to_const, unbind, ) @@ -315,7 +317,6 @@ def _run_calculation(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], [1] * ndim, - ndim, pad_value=pad_val, mode=attr["auto_pad"], ) @@ -411,69 +412,6 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name="instance_norm")(inputs, attr, params) -def autopad( - data, - strides, - kernel_shape, - dilations, - ndim, - pad_type="constant", - deconv=False, - mode="SAME_UPPER", - pad_value=0.0, -): - """ - Perform autopadding with dynamic input shapes - """ - # get attributes as constants - strides = _op.const(np.array(strides), dtype="int64") - dilated_kernel_shape = _op.const( - np.array( - [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] - ), - dtype="int64", - ) - # get input shape - shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) - - # set up integer constants - zero = _op.const(0, dtype="int64") - one = _op.const(1, dtype="int64") - two = _op.const(2, dtype="int64") - - # Calculate total padding - mod = _op.mod(shape, strides) - - left = _op.maximum(dilated_kernel_shape - strides, zero) - right = _op.maximum(dilated_kernel_shape - mod, zero) - - total_pad = _op.where(_op.equal(mod, zero), left, right) - if deconv: - total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad - - # split total padding into before and after - pad_before = _op.floor_divide(total_pad, two) - pad_after = total_pad - pad_before - - # combine - if "LOWER" in mode: - pad = _op.concatenate( - [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 - ) - else: - pad = _op.concatenate( - [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 - ) - - # pad N and C with zeros - pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - - if isinstance(pad_value, (float, int)): - pad_value = _op.const(pad_value) - - return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) - - class Conv(OnnxOpConverter): """Operator converter for Conv.""" @@ -501,7 +439,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -582,7 +519,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, deconv=True, mode=attr["auto_pad"], ) @@ -974,7 +910,6 @@ def _impl_v1(cls, inputs, attr, params): attr["strides"], attr["kernel_shape"], [1] * ndim, - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -1410,14 +1345,6 @@ def _impl_v9(cls, inputs, attr, params): return out -def shape_of(x, dtype="int64"): - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(shape, dtype) - return _op.shape_of(x, dtype) - - class Shape(OnnxOpConverter): """Operator converter for Shape.""" @@ -3440,7 +3367,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=x_zero_point.data, mode=attr["auto_pad"], ) @@ -3810,7 +3736,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * 
(ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=data_zp, mode=attr["auto_pad"], ) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 70a46c966b84..e7f9159efedd 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -32,11 +32,13 @@ from .. import ty as _ty from .. import op as _op from .common import ( + autopad, fold_constant, get_relay_op, infer_shape, infer_type, infer_value, + shape_of, try_infer_value, new_var, ) @@ -44,58 +46,6 @@ __all__ = ["from_paddle"] -def _autopad( - data, - strides, - kernel_shape, - dilations=(1, 1), - pad_type="constant", - pad_value=0.0, -): - """Perform padding under SAME mode for dynamic and fixed input shapes. - This implementation refers to ONNX frontend. - """ - - # get attributes as constants - strides = _op.const(np.array(strides), dtype="int64") - dilated_kernel_shape = _op.const( - np.array( - [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] - ), - dtype="int64", - ) - # get input shape - ndim = len(infer_shape(data)) - shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) - - # set up integer constants - zero = _op.const(0, dtype="int64") - two = _op.const(2, dtype="int64") - - # Calculate total padding - mod = _op.mod(shape, strides) - - left = _op.maximum(dilated_kernel_shape - strides, zero) - right = _op.maximum(dilated_kernel_shape - mod, zero) - - total_pad = _op.where(_op.equal(mod, zero), left, right) - - # split total padding into before and after - pad_before = _op.floor_divide(total_pad, two) - pad_after = total_pad - pad_before - - pad = _op.concatenate( - [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 - ) - - # pad N and C with zeros - pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - - if isinstance(pad_value, (float, int)): - pad_value = _op.const(pad_value) - return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) - - def _dtype_shape_promotion(inputs): """Promote data type and shape for list of tensors.""" @@ -117,16 +67,6 @@ def _dtype_shape_promotion(inputs): return inputs -def shape_of(x, dtype="int32"): - """Get shape of a tensor.""" - - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(np.array(shape), dtype) - return _op.shape_of(x, dtype) - - def _convert_dtype_value(val): """Converts a Paddle type id to a string.""" @@ -288,7 +228,7 @@ def convert_conv2d(g, op, block): paddings = [0, 0] elif padding_algorithm == "SAME": dilations = [1, 1] - input_x = _autopad(input_x, strides, [k_h, k_w], dilations) + input_x = autopad(input_x, strides, [k_h, k_w], dilations) paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: @@ -587,9 +527,9 @@ def convert_matmul(g, op, block): # This implemention almost keeps same with ONNX # Need to check input shape as batch matmul must be supported. - a_shape = shape_of(inputs[0]) + a_shape = shape_of(inputs[0], dtype="int32") a_rank = infer_shape(a_shape)[0] - b_shape = shape_of(inputs[1]) + b_shape = shape_of(inputs[1], dtype="int32") b_rank = infer_shape(b_shape)[0] # When performing a batch matmul, we need to properly handle N-dim shapes. 
     if a_rank > 2 or b_rank > 2:
@@ -676,8 +616,8 @@ def convert_mul(g, op, block):
     y = g.get_node(op.input("Y")[0])
     x_num_col_dims = op.attr("x_num_col_dims")
     y_num_col_dims = op.attr("y_num_col_dims")
-    x_shape = shape_of(x)
-    y_shape = shape_of(y)
+    x_shape = shape_of(x, dtype="int32")
+    y_shape = shape_of(y, dtype="int32")
     x_dim = infer_shape(x_shape)[0]
     y_dim = infer_shape(y_shape)[0]
     if x_num_col_dims < 0:
@@ -781,7 +721,7 @@ def convert_pool2d(g, op, block):
     if padding_algorithm == "VALID":
         paddings = [0, 0]
     elif padding_algorithm == "SAME":
-        input_x = _autopad(input_x, strides, ksize)
+        input_x = autopad(input_x, strides, ksize)
         paddings = [0, 0]
     elif padding_algorithm == "EXPLICIT":
         if len(paddings) == 2:
@@ -877,7 +817,7 @@ def convert_shape(g, op, block):
     """Operator converter for shape."""
 
     x = g.get_node(op.input("Input")[0])
-    out = shape_of(x)
+    out = shape_of(x, dtype="int32")
     g.add_node(op.output("Out")[0], out)

From 211022bd80327edc1e8528026e2b16919e8c1193 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Fri, 22 Oct 2021 11:14:09 +0000
Subject: [PATCH 8/8] add comment for conv2d

---
 python/tvm/relay/frontend/paddlepaddle.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index e7f9159efedd..ef361d6c55e8 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -227,6 +227,9 @@ def convert_conv2d(g, op, block):
     if padding_algorithm == "VALID":
         paddings = [0, 0]
     elif padding_algorithm == "SAME":
+        # Handle a legacy behavior of PaddlePaddle:
+        # when padding_algorithm == "SAME",
+        # dilations are forced to [1, 1]
        dilations = [1, 1]
         input_x = autopad(input_x, strides, [k_h, k_w], dilations)
         paddings = [0, 0]
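
A minimal end-to-end sketch of driving the `from_paddle` entry point documented above, through the SAME-padding path that now goes through the shared `autopad` helper. Because `autopad` builds the padding computation out of Relay ops (`shape_of`, `mod`, `where`), SAME mode also works when spatial dimensions are dynamic, which the removed `_get_pad_size` path could not handle when stride != 1. The layer definition, the save path "./conv_same_example", the input name "x", and the "llvm" target below are illustrative assumptions, not values taken from the patches themselves.

import paddle
import paddle.nn as nn
import tvm
from tvm import relay


class ConvSame(nn.Layer):
    def __init__(self):
        super(ConvSame, self).__init__()
        # padding="SAME" with stride > 1 routes through the new autopad-based conversion
        self.conv = nn.Conv2D(3, 6, 3, stride=2, padding="SAME")

    @paddle.jit.to_static
    def forward(self, x):
        return self.conv(x)


model = ConvSame()
model.eval()

# Save and reload to obtain a TranslatedLayer, one of the two inputs
# accepted by from_paddle (the other being a static Program).
input_spec = [paddle.static.InputSpec(shape=[1, 3, 10, 10], dtype="float32", name="x")]
paddle.jit.save(model, "./conv_same_example", input_spec=input_spec)
layer = paddle.jit.load("./conv_same_example")

# shape_dict is keyed by the model's input name ("x" here by construction).
shape_dict = {"x": [1, 3, 10, 10]}
mod, params = relay.frontend.from_paddle(layer, shape_dict=shape_dict)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)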