Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend][Onnx] Allow A to B broadcasting of batch_matmul and reverse strided slice #6681

Merged
merged 6 commits into from
Oct 15, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,11 @@ def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
# Need to check input shape as batch matmul must be supported.
a_shape = _op.shape_of(inputs[0])
a_rank = infer_shape(a_shape)[0]
b_shape = _op.shape_of(inputs[1])
b_rank = infer_shape(b_shape)[0]
# When performing a batch matmul, we need to properly handle N-dim shapes.
if infer_shape(a_shape)[0] > 2:
b_shape = _op.shape_of(inputs[1])
if a_rank > 2 or b_rank > 2:

def flatten_to_3d(x, x_shape):
ndims = infer_shape(x_shape)[0]
Expand All @@ -532,10 +534,18 @@ def flatten_to_3d(x, x_shape):
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
# Determine the output batch dimension.
if a_rank >= b_rank:
out_batch = _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 2])
else:
out_batch = _op.strided_slice(b_shape, [0], [infer_shape(b_shape)[0] - 2])
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
# Reshape output to original dimensions.
final_shape = _op.concatenate(
[
_op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]),
out_batch,
_op.strided_slice(
a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
),
_op.strided_slice(
b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
),
Expand Down Expand Up @@ -684,9 +694,7 @@ def _impl_v11(cls, inputs, attr, params):
else:
value = 0

pads_shape = infer_shape(pads)
dims = int(pads_shape[0] / 2)
pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims)))
pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
pad_mode = attr.get("mode", b"constant").decode("utf-8")

if not pad_mode in ["constant", "edge", "reflect"]:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
else:
if end[i] > data_shape[i]:
cend = int64(data_shape[i])
elif end[i] < -data_shape[i]:
cend = int64(-1)
else:
cend = int64(end[i])
if cend < 0:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
else:
if end[i] > data.shape[i]:
cend = int64(data.shape[i])
elif end[i] < -data.shape[i]:
cend = int64(-1)
else:
cend = int64(end[i])
if cend < 0:
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,10 @@ def dense_shape_func(attrs, inputs, _):
def _batch_matmul_shape_func(data_shape, weight_shape):
    # Shape function for batch_matmul: builds the output shape tensor from the
    # shape tensors of the two operands.
    #
    # data_shape   -- 1-D tensor holding the shape of the first operand.
    # weight_shape -- 1-D tensor holding the shape of the second operand.
    # Returns a 1-D int64 tensor with as many entries as data_shape.
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0] - 1):
        if i == 0:
            # Leading (batch) dimension: take the larger of the two operands so
            # that either side may supply the full batch size.
            # NOTE(review): presumably the smaller side is expected to be 1
            # (broadcast); that is not checked here -- confirm against the
            # type relation.
            out[i] = max(data_shape[i], weight_shape[i])
        else:
            # Remaining leading dimensions are copied from the first operand.
            out[i] = data_shape[i]
    # The last output dimension is read from the second-to-last entry of
    # weight_shape.
    # NOTE(review): this looks like it assumes the second operand is stored
    # transposed, i.e. (batch, N, K) -- confirm.
    out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]

    return out
Expand Down
6 changes: 0 additions & 6 deletions src/relay/op/dyn/nn/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
int pad_width_rank = pad_width->shape.size();
CHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D";

auto pad_width_dim1 = pad_width->shape[0].as<IntImmNode>();
auto pad_width_dim2 = pad_width->shape[1].as<IntImmNode>();

CHECK(pad_width_dim1->value == data_rank && pad_width_dim2->value == 2)
<< "Pad width must have shape (N, 2), where N is the rank of input data";

const PadAttrs* param = attrs.as<PadAttrs>();
CHECK(param != nullptr);

Expand Down
6 changes: 5 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,11 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
is_dyn = true;
oshape.push_back(Any());
} else {
oshape.push_back(x->shape[i]);
if (i == 0) {
oshape.push_back(max(x->shape[i], y->shape[i]));
} else {
oshape.push_back(x->shape[i]);
}
}
}
if (!is_dyn) {
Expand Down
15 changes: 8 additions & 7 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,10 +992,9 @@ def test_matmul():
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)


def verify_batch_matmul(a_shape, b_shape, target, ctx):
def verify_batch_matmul(a_shape, b_shape, out_shape, target, ctx):
a_array = np.random.uniform(size=a_shape).astype("float32")
b_array = np.random.uniform(size=b_shape).astype("float32")
out_np = np.matmul(a_array, b_array)

mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])

Expand All @@ -1006,21 +1005,23 @@ def verify_batch_matmul(a_shape, b_shape, target, ctx):
helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
)

model = helper.make_model(graph, producer_name="matmul_test")
onnx_out = get_onnxruntime_output(model, [a_array, b_array], "float32")[0]

tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)


# TODO(mbrookhart): enable cuda once VM supports heterogeneous execution
@tvm.testing.parametrize_targets("llvm")
def test_batch_matmul(target, ctx):
    """Run the MatMul verification helper over several operand-rank combinations.

    Each call passes ``(a_shape, b_shape, out_shape, target, ctx)`` to
    ``verify_batch_matmul``; ``out_shape`` is the expected shape of the
    product and is given explicitly rather than derived.
    """
    # Both operands 4-D with identical batch dimensions.
    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, ctx)
    # 3-D A against a 2-D B: B is shared across A's batch.
    verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, ctx)
    # 4-D A against a 2-D B.
    verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4), target, ctx)
    # 2-D A against a 3-D B: the batch dimension comes from B.
    verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, ctx)


def verify_simple_dynamic_model(a_shape, b_shape, target, ctx):
Expand Down
3 changes: 3 additions & 0 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,9 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
# Test backwards slicing.
verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
# Test slice mode.
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
)
Expand Down