[Relay][Frontend][Onnx] Allow A to B broadcasting of batch_matmul and reverse strided slice (apache#6681)

* Slice and batch_matmul fixes.

* Bug fix in shape inference.

* Test backwards strided slice.

* Fix batch_matmul dynamic shape function.

* Formatting.

* Fix edge case for implicit broadcast
jwfromm authored and Trevor Morris committed Dec 2, 2020
1 parent ce50cbd commit cfb5eb1
Showing 8 changed files with 54 additions and 21 deletions.
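For context on the main change: ONNX MatMul follows NumPy matmul semantics, where the trailing two axes are multiplied and any leading (batch) axes are broadcast between the two inputs, so either operand may supply the output batch dimensions. A quick NumPy illustration, using shapes taken from the new tests below:

import numpy as np

a = np.random.rand(4, 3)         # no batch dims
b = np.random.rand(2, 3, 4)      # batch dim of 2
print(np.matmul(a, b).shape)     # (2, 4, 4) -- the batch dim comes from b

a = np.random.rand(2, 4, 3)      # batch dim of 2
b = np.random.rand(1, 3, 4)      # batch dim of 1, broadcast up to 2
print(np.matmul(a, b).shape)     # (2, 4, 4)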
33 changes: 27 additions & 6 deletions python/tvm/relay/frontend/onnx.py
@@ -513,9 +513,11 @@ def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
         a_shape = _op.shape_of(inputs[0])
+        a_rank = infer_shape(a_shape)[0]
+        b_shape = _op.shape_of(inputs[1])
+        b_rank = infer_shape(b_shape)[0]
         # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if infer_shape(a_shape)[0] > 2:
-            b_shape = _op.shape_of(inputs[1])
+        if a_rank > 2 or b_rank > 2:

             def flatten_to_3d(x, x_shape):
                 ndims = infer_shape(x_shape)[0]
@@ -532,10 +534,31 @@ def flatten_to_3d(x, x_shape):
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
+            # Determine the output batch dimension.
+            if a_rank > b_rank:
+                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+            elif a_rank < b_rank:
+                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+            # If it's unclear how broadcasting should be applied, the output
+            # shape is determined by choosing the maximum value from each input.
+            else:
+                out_batch = _op.concatenate(
+                    [
+                        _op.maximum(
+                            _op.strided_slice(a_shape, [i], [i + 1]),
+                            _op.strided_slice(b_shape, [i], [i + 1]),
+                        )
+                        for i in range(a_rank - 2)
+                    ],
+                    0,
+                )
             # Reshape output to original dimensions.
             final_shape = _op.concatenate(
                 [
-                    _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]),
+                    out_batch,
+                    _op.strided_slice(
+                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
+                    ),
                     _op.strided_slice(
                         b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
                     ),
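The batch-dimension logic added above can be hard to read through Relay's shape-tensor ops. A plain-Python sketch of the same rule on concrete shape tuples (the helper name is illustrative, not part of the patch):

def out_batch_dims(a_shape, b_shape):
    # The higher-rank input supplies the output batch dims outright.
    a_rank, b_rank = len(a_shape), len(b_shape)
    if a_rank > b_rank:
        return a_shape[: a_rank - 2]
    if a_rank < b_rank:
        return b_shape[: b_rank - 2]
    # Equal ranks: the broadcasting direction is ambiguous, so take the
    # elementwise maximum of the batch dims (valid when the smaller is 1).
    return tuple(max(x, y) for x, y in zip(a_shape[:-2], b_shape[:-2]))

assert out_batch_dims((2, 3, 4, 3), (3, 4)) == (2, 3)
assert out_batch_dims((2, 4, 3), (1, 3, 4)) == (2,)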
@@ -684,9 +707,7 @@ def _impl_v11(cls, inputs, attr, params):
         else:
             value = 0
 
-        pads_shape = infer_shape(pads)
-        dims = int(pads_shape[0] / 2)
-        pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims)))
+        pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
         pad_mode = attr.get("mode", b"constant").decode("utf-8")
 
         if not pad_mode in ["constant", "edge", "reflect"]:
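The Pad change works because ONNX supplies pads as a flat tensor of all begin values followed by all end values; reshaping to (2, -1) and transposing yields the (N, 2) rows Relay expects without needing the static rank. A NumPy sketch of that transformation:

import numpy as np

pads = np.array([1, 2, 3, 4])             # 2-D input: begins (1, 2), ends (3, 4)
pad_width = np.transpose(pads.reshape(2, -1))
print(pad_width)                           # [[1 3]
                                           #  [2 4]] -- (before, after) per axis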
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -172,6 +172,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
         else:
             if end[i] > data_shape[i]:
                 cend = int64(data_shape[i])
+            elif end[i] < -data_shape[i]:
+                cend = int64(-1)
             else:
                 cend = int64(end[i])
             if cend < 0:
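The new elif handles an end value that points past the front of the axis, which is exactly what a full reverse slice produces; clamping cend to -1 lets the slice run all the way down to index 0. The equivalent NumPy behavior (a sketch of the semantics, not the hybrid-script code):

import numpy as np

x = np.arange(3)
# begin=-1, end=-5, stride=-1 on an axis of size 3: end < -3, so the stop
# is clamped one position before index 0 and the whole axis is covered.
print(x[-1:-5:-1])   # [2 1 0]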
2 changes: 2 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
@@ -173,6 +173,8 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
         else:
             if end[i] > data.shape[i]:
                 cend = int64(data.shape[i])
+            elif end[i] < -data.shape[i]:
+                cend = int64(-1)
             else:
                 cend = int64(end[i])
             if cend < 0:
5 changes: 4 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -982,7 +982,10 @@ def dense_shape_func(attrs, inputs, _):
 def _batch_matmul_shape_func(data_shape, weight_shape):
     out = output_tensor((data_shape.shape[0],), "int64")
     for i in const_range(out.shape[0] - 1):
-        out[i] = data_shape[i]
+        if i == 0:
+            out[i] = max(data_shape[i], weight_shape[i])
+        else:
+            out[i] = data_shape[i]
     out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
 
     return out
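Restated as ordinary Python over concrete shapes (batch_matmul takes data of shape (B, M, K) and transposed weight of shape (B, N, K); this sketch is not part of the patch):

def batch_matmul_out_shape(data_shape, weight_shape):
    # Batch dim: take the max so a dim of 1 broadcasts against the other input.
    batch = max(data_shape[0], weight_shape[0])
    # Remaining dims: M from data, N from weight's second-to-last axis.
    return (batch,) + tuple(data_shape[1:-1]) + (weight_shape[-2],)

assert batch_matmul_out_shape((1, 4, 3), (2, 4, 3)) == (2, 4, 4)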
6 changes: 0 additions & 6 deletions src/relay/op/dyn/nn/pad.cc
@@ -57,12 +57,6 @@ bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   int pad_width_rank = pad_width->shape.size();
   CHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D";
 
-  auto pad_width_dim1 = pad_width->shape[0].as<IntImmNode>();
-  auto pad_width_dim2 = pad_width->shape[1].as<IntImmNode>();
-
-  CHECK(pad_width_dim1->value == data_rank && pad_width_dim2->value == 2)
-      << "Pad width must have shape (N, 2), where N is the rank of input data";
-
   const PadAttrs* param = attrs.as<PadAttrs>();
   CHECK(param != nullptr);
6 changes: 5 additions & 1 deletion src/relay/op/nn/nn.cc
@@ -859,7 +859,11 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
       is_dyn = true;
       oshape.push_back(Any());
     } else {
-      oshape.push_back(x->shape[i]);
+      if (i == 0) {
+        oshape.push_back(max(x->shape[i], y->shape[i]));
+      } else {
+        oshape.push_back(x->shape[i]);
+      }
     }
   }
   if (!is_dyn) {
18 changes: 11 additions & 7 deletions tests/python/frontend/onnx/test_forward.py
@@ -992,10 +992,9 @@ def test_matmul():
     tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
-def verify_batch_matmul(a_shape, b_shape, target, ctx):
+def verify_batch_matmul(a_shape, b_shape, out_shape, target, ctx):
     a_array = np.random.uniform(size=a_shape).astype("float32")
     b_array = np.random.uniform(size=b_shape).astype("float32")
-    out_np = np.matmul(a_array, b_array)
 
     mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
 
@@ -1006,21 +1005,26 @@ def verify_batch_matmul(a_shape, b_shape, target, ctx):
helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
)

model = helper.make_model(graph, producer_name="matmul_test")
onnx_out = get_onnxruntime_output(model, [a_array, b_array], "float32")[0]

tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)


 # TODO(mbrookhart): enable cuda once VM supports heterogeneous execution
 @tvm.testing.parametrize_targets("llvm")
 def test_batch_matmul(target, ctx):
-    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
-    verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
-    verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)
+    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, ctx)
+    verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, ctx)
+    verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4), target, ctx)
+    # Test implicit broadcasting.
+    verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, ctx)
+    verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4), target, ctx)
+    verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4), target, ctx)
 
 
 def verify_simple_dynamic_model(a_shape, b_shape, target, ctx):
3 changes: 3 additions & 0 deletions tests/python/relay/test_op_level4.py
@@ -392,6 +392,9 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
     verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
     verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+    # Test backwards slicing.
+    verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
+    # Test slice mode.
     verify(
         (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
     )
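The new backwards-slicing case corresponds to reversing every axis, and NumPy agrees on the resulting shape:

import numpy as np

x = np.random.rand(3, 4, 3)
# begin=[-1, -1, -1], end=[-5, -5, -5], strides=[-1, -1, -1] reverses each
# axis; every end is below -dim, so the slices span the full extent.
y = x[-1:-5:-1, -1:-5:-1, -1:-5:-1]
assert y.shape == (3, 4, 3)
assert (y == x[::-1, ::-1, ::-1]).all()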
