Commit

Merge pull request #67 from heliqi/pr002
Refactor autopad in the onnx.py and paddlepaddle.py to relay/frontend…
jiangjiajun authored Oct 22, 2021
2 parents ab3278a + fee1886 commit 9c442a9
Showing 3 changed files with 85 additions and 146 deletions.
74 changes: 74 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -28,6 +28,7 @@
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
from .. import ty as _ty
from .. import analysis

# pylint: disable=invalid-name
@@ -594,6 +595,16 @@ def try_infer_value(val, on_success=None, on_failure=None):
return val, False


def shape_of(x, dtype="int64"):
"""Get shape of a tensor."""

ttype = infer_type(x).checked_type
if not _ty.is_dynamic(ttype):
shape = list(ttype.shape)
return _expr.const(shape, dtype)
return _op.shape_of(x, dtype)


def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
return _expr.var(name_hint, type_annotation, shape, dtype)

@@ -837,6 +848,69 @@ def lstm_cell(
return outputs_list, hidden_state, cell_state


def autopad(
data,
strides,
kernel_shape,
dilations=(1, 1),
pad_type="constant",
deconv=False,
mode="SAME_UPPER",
pad_value=0.0,
):
"""
Perform autopadding with dynamic input shapes
"""
# get attributes as constants
strides = _op.const(np.array(strides), dtype="int64")
dilated_kernel_shape = _op.const(
np.array(
[(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
),
dtype="int64",
)
# get input shape
ndim = len(infer_shape(data))
shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])

# set up integer constants
zero = _op.const(0, dtype="int64")
one = _op.const(1, dtype="int64")
two = _op.const(2, dtype="int64")

# Calculate total padding
mod = _op.mod(shape, strides)

left = _op.maximum(dilated_kernel_shape - strides, zero)
right = _op.maximum(dilated_kernel_shape - mod, zero)

total_pad = _op.where(_op.equal(mod, zero), left, right)
if deconv:
total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad

# split total padding into before and after
pad_before = _op.floor_divide(total_pad, two)
pad_after = total_pad - pad_before

# combine
if "LOWER" in mode:
pad = _op.concatenate(
[_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1
)
else:
pad = _op.concatenate(
[_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
)

# pad N and C with zeros
pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)

if isinstance(pad_value, (float, int)):
pad_value = _op.const(pad_value)

return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type)


def ensure_scalar_shape(x):
"""
Assume that `x` is a tensor with one element (regardless of tensor rank).
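
Note on the padding math above: autopad builds the SAME-padding arithmetic entirely out of Relay ops so it keeps working when spatial dimensions are dynamic. For a fully static case the same computation can be written as a plain NumPy reference; the sketch below is illustrative only (the function name and example values are made up here and are not part of this patch).

import numpy as np

def same_pad_reference(spatial_shape, strides, kernel_shape, dilations,
                       mode="SAME_UPPER", deconv=False):
    # mirror the Relay-op arithmetic in common.autopad for static shapes
    shape = np.asarray(spatial_shape, dtype="int64")
    strides = np.asarray(strides, dtype="int64")
    dilated_kernel = (np.asarray(kernel_shape, dtype="int64") - 1) * np.asarray(dilations, dtype="int64") + 1

    mod = shape % strides
    total_pad = np.where(mod == 0,
                         np.maximum(dilated_kernel - strides, 0),
                         np.maximum(dilated_kernel - mod, 0))
    if deconv:
        total_pad = np.asarray(kernel_shape, dtype="int64") - 1 - total_pad

    pad_before = total_pad // 2
    pad_after = total_pad - pad_before
    if "LOWER" in mode:  # SAME_LOWER puts the extra pixel in front
        pad_before, pad_after = pad_after, pad_before
    # N and C are never padded, matching the np.zeros([2, 2]) prefix in autopad
    return [(0, 0), (0, 0)] + list(zip(pad_before.tolist(), pad_after.tolist()))

# e.g. a 5x5 input with a 3x3 kernel and stride 2 needs one pixel of padding per side
assert same_pad_reference([5, 5], [2, 2], [3, 3], [1, 1]) == [(0, 0), (0, 0), (1, 1), (1, 1)]
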
79 changes: 2 additions & 77 deletions python/tvm/relay/frontend/onnx.py
@@ -38,6 +38,7 @@
from .. import ty as _ty
from .. import vision as _vision
from .common import (
autopad,
AttrCvt,
Renamer,
ensure_scalar_shape,
@@ -51,6 +52,7 @@
infer_value,
lstm_cell,
new_var,
shape_of,
try_resolve_var_to_const,
unbind,
)
@@ -315,7 +317,6 @@ def _run_calculation(cls, inputs, attr, params):
attr.get("strides", [1] * (ndim - 2)),
attr["kernel_shape"],
[1] * ndim,
ndim,
pad_value=pad_val,
mode=attr["auto_pad"],
)
@@ -411,69 +412,6 @@ def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name="instance_norm")(inputs, attr, params)


def autopad(
data,
strides,
kernel_shape,
dilations,
ndim,
pad_type="constant",
deconv=False,
mode="SAME_UPPER",
pad_value=0.0,
):
"""
Perform autopadding with dynamic input shapes
"""
# get attributes as constants
strides = _op.const(np.array(strides), dtype="int64")
dilated_kernel_shape = _op.const(
np.array(
[(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
),
dtype="int64",
)
# get input shape
shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])

# set up integer constants
zero = _op.const(0, dtype="int64")
one = _op.const(1, dtype="int64")
two = _op.const(2, dtype="int64")

# Calculate total padding
mod = _op.mod(shape, strides)

left = _op.maximum(dilated_kernel_shape - strides, zero)
right = _op.maximum(dilated_kernel_shape - mod, zero)

total_pad = _op.where(_op.equal(mod, zero), left, right)
if deconv:
total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad

# split total padding into before and after
pad_before = _op.floor_divide(total_pad, two)
pad_after = total_pad - pad_before

# combine
if "LOWER" in mode:
pad = _op.concatenate(
[_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1
)
else:
pad = _op.concatenate(
[_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
)

# pad N and C with zeros
pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)

if isinstance(pad_value, (float, int)):
pad_value = _op.const(pad_value)

return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type)


class Conv(OnnxOpConverter):
"""Operator converter for Conv."""

@@ -501,7 +439,6 @@ def _impl_v1(cls, inputs, attr, params):
attr.get("strides", [1] * (ndim - 2)),
attr["kernel_shape"],
attr.get("dilations", [1] * (ndim - 2)),
ndim,
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
@@ -582,7 +519,6 @@ def _impl_v1(cls, inputs, attr, params):
attr.get("strides", [1] * (ndim - 2)),
attr["kernel_shape"],
attr.get("dilations", [1] * (ndim - 2)),
ndim,
deconv=True,
mode=attr["auto_pad"],
)
@@ -974,7 +910,6 @@ def _impl_v1(cls, inputs, attr, params):
attr["strides"],
attr["kernel_shape"],
[1] * ndim,
ndim,
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
@@ -1410,14 +1345,6 @@ def _impl_v9(cls, inputs, attr, params):
return out


def shape_of(x, dtype="int64"):
ttype = infer_type(x).checked_type
if not _ty.is_dynamic(ttype):
shape = list(ttype.shape)
return _expr.const(shape, dtype)
return _op.shape_of(x, dtype)


class Shape(OnnxOpConverter):
"""Operator converter for Shape."""

@@ -3440,7 +3367,6 @@ def _impl_v10(cls, inputs, attr, params):
attr.get("strides", [1] * (ndim - 2)),
attr["kernel_shape"],
attr.get("dilations", [1] * (ndim - 2)),
ndim,
pad_value=x_zero_point.data,
mode=attr["auto_pad"],
)
@@ -3810,7 +3736,6 @@ def _impl_v10(cls, inputs, attr, params):
attr.get("strides", [1] * (ndim - 2)),
attr["kernel_shape"],
attr.get("dilations", [1] * (ndim - 2)),
ndim,
pad_value=data_zp,
mode=attr["auto_pad"],
)
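
For onnx.py the change is mechanical: the local autopad and shape_of definitions are deleted, and every call site simply drops the ndim positional argument, since the shared helper now derives the rank from infer_shape(data) itself. A small sketch of a post-refactor call is shown below; the attribute values and input variable are made-up toy inputs, not taken from this patch.

from tvm import relay
from tvm.relay.frontend.common import autopad

ndim = 4
attr = {"kernel_shape": [3, 3], "auto_pad": "SAME_UPPER"}
# a dynamic height makes the helper fall back to a runtime shape_of
data = relay.var("data", shape=(1, 3, relay.Any(), 32), dtype="float32")

padded = autopad(
    data,
    attr.get("strides", [1] * (ndim - 2)),
    attr["kernel_shape"],
    attr.get("dilations", [1] * (ndim - 2)),
    # ndim is no longer passed; common.autopad calls infer_shape(data) internally
    mode=attr["auto_pad"],
)
print(padded)
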
78 changes: 9 additions & 69 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -32,70 +32,20 @@
from .. import ty as _ty
from .. import op as _op
from .common import (
autopad,
fold_constant,
get_relay_op,
infer_shape,
infer_type,
infer_value,
shape_of,
try_infer_value,
new_var,
)

__all__ = ["from_paddle"]


def _autopad(
data,
strides,
kernel_shape,
dilations=(1, 1),
pad_type="constant",
pad_value=0.0,
):
"""Perform padding under SAME mode for dynamic and fixed input shapes.
This implementation refers to ONNX frontend.
"""

# get attributes as constants
strides = _op.const(np.array(strides), dtype="int64")
dilated_kernel_shape = _op.const(
np.array(
[(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
),
dtype="int64",
)
# get input shape
ndim = len(infer_shape(data))
shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])

# set up integer constants
zero = _op.const(0, dtype="int64")
two = _op.const(2, dtype="int64")

# Calculate total padding
mod = _op.mod(shape, strides)

left = _op.maximum(dilated_kernel_shape - strides, zero)
right = _op.maximum(dilated_kernel_shape - mod, zero)

total_pad = _op.where(_op.equal(mod, zero), left, right)

# split total padding into before and after
pad_before = _op.floor_divide(total_pad, two)
pad_after = total_pad - pad_before

pad = _op.concatenate(
[_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
)

# pad N and C with zeros
pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)

if isinstance(pad_value, (float, int)):
pad_value = _op.const(pad_value)
return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type)


def _dtype_shape_promotion(inputs):
"""Promote data type and shape for list of tensors."""

@@ -117,16 +67,6 @@ def _dtype_shape_promotion(inputs):
return inputs


def shape_of(x, dtype="int32"):
"""Get shape of a tensor."""

ttype = infer_type(x).checked_type
if not _ty.is_dynamic(ttype):
shape = list(ttype.shape)
return _expr.const(np.array(shape), dtype)
return _op.shape_of(x, dtype)


def _convert_dtype_value(val):
"""Converts a Paddle type id to a string."""

@@ -288,7 +228,7 @@ def convert_conv2d(g, op, block):
paddings = [0, 0]
elif padding_algorithm == "SAME":
dilations = [1, 1]
input_x = _autopad(input_x, strides, [k_h, k_w], dilations)
input_x = autopad(input_x, strides, [k_h, k_w], dilations)
paddings = [0, 0]
elif padding_algorithm == "EXPLICIT":
if len(paddings) == 2:
@@ -587,9 +527,9 @@ def convert_matmul(g, op, block):

# This implementation closely follows the ONNX frontend.
# Need to check input shapes, since batch matmul must be supported.
a_shape = shape_of(inputs[0])
a_shape = shape_of(inputs[0], dtype="int32")
a_rank = infer_shape(a_shape)[0]
b_shape = shape_of(inputs[1])
b_shape = shape_of(inputs[1], dtype="int32")
b_rank = infer_shape(b_shape)[0]
# When performing a batch matmul, we need to properly handle N-dim shapes.
if a_rank > 2 or b_rank > 2:
@@ -676,8 +616,8 @@ def convert_mul(g, op, block):
y = g.get_node(op.input("Y")[0])
x_num_col_dims = op.attr("x_num_col_dims")
y_num_col_dims = op.attr("y_num_col_dims")
x_shape = shape_of(x)
y_shape = shape_of(y)
x_shape = shape_of(x, dtype="int32")
y_shape = shape_of(y, dtype="int32")
x_dim = infer_shape(x_shape)[0]
y_dim = infer_shape(y_shape)[0]
if x_num_col_dims < 0:
@@ -781,7 +721,7 @@ def convert_pool2d(g, op, block):
if padding_algorithm == "VALID":
paddings = [0, 0]
elif padding_algorithm == "SAME":
input_x = _autopad(input_x, strides, ksize)
input_x = autopad(input_x, strides, ksize)
paddings = [0, 0]
elif padding_algorithm == "EXPLICIT":
if len(paddings) == 2:
@@ -877,7 +817,7 @@ def convert_shape(g, op, block):
"""Operator converter for shape."""

x = g.get_node(op.input("Input")[0])
out = shape_of(x)
out = shape_of(x, dtype="int32")
g.add_node(op.output("Out")[0], out)


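For paddlepaddle.py the only behavioural subtlety is the shape dtype: the deleted local shape_of defaulted to int32, while the shared helper in common.py defaults to int64, so the Paddle call sites now request int32 explicitly. A minimal sketch of the difference, assuming a TVM build that contains this patch (the variable below is a made-up example):

from tvm import relay
from tvm.relay.frontend.common import shape_of

x = relay.var("x", shape=(2, 3), dtype="float32")
print(shape_of(x))                 # int64 constant [2, 3] -- the shared default
print(shape_of(x, dtype="int32"))  # int32 constant [2, 3] -- what the Paddle call sites now pass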
