diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 0598094398f7a..9fae94b5a8a1e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -513,9 +513,11 @@ def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
         a_shape = _op.shape_of(inputs[0])
+        a_rank = infer_shape(a_shape)[0]
+        b_shape = _op.shape_of(inputs[1])
+        b_rank = infer_shape(b_shape)[0]
         # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if infer_shape(a_shape)[0] > 2:
-            b_shape = _op.shape_of(inputs[1])
+        if a_rank > 2 or b_rank > 2:
 
             def flatten_to_3d(x, x_shape):
                 ndims = infer_shape(x_shape)[0]
@@ -532,10 +534,31 @@ def flatten_to_3d(x, x_shape):
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
+            # Determine the output batch dimension.
+            if a_rank > b_rank:
+                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+            elif a_rank < b_rank:
+                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+            # If it's unclear how broadcasting should be applied, the output
+            # shape is determined by choosing the maximum value from each input.
+            else:
+                out_batch = _op.concatenate(
+                    [
+                        _op.maximum(
+                            _op.strided_slice(a_shape, [i], [i + 1]),
+                            _op.strided_slice(b_shape, [i], [i + 1]),
+                        )
+                        for i in range(a_rank - 2)
+                    ],
+                    0,
+                )
             # Reshape output to original dimensions.
             final_shape = _op.concatenate(
                 [
-                    _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]),
+                    out_batch,
+                    _op.strided_slice(
+                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
+                    ),
                     _op.strided_slice(
                         b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
                     ),
@@ -684,9 +707,7 @@ def _impl_v11(cls, inputs, attr, params):
         else:
             value = 0
 
-        pads_shape = infer_shape(pads)
-        dims = int(pads_shape[0] / 2)
-        pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims)))
+        pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
         pad_mode = attr.get("mode", b"constant").decode("utf-8")
 
         if not pad_mode in ["constant", "edge", "reflect"]:
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index be0382aae8d2b..3b70d78cf967c 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -172,6 +172,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
         else:
             if end[i] > data_shape[i]:
                 cend = int64(data_shape[i])
+            elif end[i] < -data_shape[i]:
+                cend = int64(-1)
             else:
                 cend = int64(end[i])
             if cend < 0:
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index dedd3dfb66d77..559d63acaefd7 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -173,6 +173,8 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
         else:
             if end[i] > data.shape[i]:
                 cend = int64(data.shape[i])
+            elif end[i] < -data.shape[i]:
+                cend = int64(-1)
             else:
                 cend = int64(end[i])
             if cend < 0:
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 9e47dc0a17f1a..e1aabe1e15b59 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -982,7 +982,10 @@ def dense_shape_func(attrs, inputs, _):
 def _batch_matmul_shape_func(data_shape, weight_shape):
     out = output_tensor((data_shape.shape[0],), "int64")
     for i in const_range(out.shape[0] - 1):
-        out[i] = data_shape[i]
+        if i == 0:
+            out[i] = max(data_shape[i], weight_shape[i])
+        else:
+            out[i] = data_shape[i]
     out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
 
     return out
diff --git a/src/relay/op/dyn/nn/pad.cc b/src/relay/op/dyn/nn/pad.cc
index 8a17f50df0dfe..73daccbd97fde 100644
--- a/src/relay/op/dyn/nn/pad.cc
+++ b/src/relay/op/dyn/nn/pad.cc
@@ -57,12 +57,6 @@ bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   int pad_width_rank = pad_width->shape.size();
   CHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D";
 
-  auto pad_width_dim1 = pad_width->shape[0].as<IntImmNode>();
-  auto pad_width_dim2 = pad_width->shape[1].as<IntImmNode>();
-
-  CHECK(pad_width_dim1->value == data_rank && pad_width_dim2->value == 2)
-      << "Pad width must have shape (N, 2), where N is the rank of input data";
-
   const PadAttrs* param = attrs.as<PadAttrs>();
   CHECK(param != nullptr);
 
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index cbd7ae47acd78..58dfab27a933f 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -859,7 +859,11 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
       is_dyn = true;
       oshape.push_back(Any());
     } else {
-      oshape.push_back(x->shape[i]);
+      if (i == 0) {
+        oshape.push_back(max(x->shape[i], y->shape[i]));
+      } else {
+        oshape.push_back(x->shape[i]);
+      }
     }
   }
   if (!is_dyn) {
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index ae32012e42e8b..07e6dc4652686 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -992,10 +992,9 @@ def test_matmul():
     tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
-def verify_batch_matmul(a_shape, b_shape, target, ctx):
+def verify_batch_matmul(a_shape, b_shape, out_shape, target, ctx):
     a_array = np.random.uniform(size=a_shape).astype("float32")
     b_array = np.random.uniform(size=b_shape).astype("float32")
-    out_np = np.matmul(a_array, b_array)
 
     mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
 
@@ -1006,21 +1005,26 @@ def verify_batch_matmul(a_shape, b_shape, target, ctx):
             helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
             helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
     )
 
     model = helper.make_model(graph, producer_name="matmul_test")
+    onnx_out = get_onnxruntime_output(model, [a_array, b_array], "float32")[0]
 
     tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
-    tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+    tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
 # TODO(mbrookhart): enable cuda once VM supports heterogenous execution
 @tvm.testing.parametrize_targets("llvm")
 def test_batch_matmul(target, ctx):
-    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
-    verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
-    verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)
+    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, ctx)
+    verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, ctx)
+    verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4), target, ctx)
+    # Test implicit broadcasting.
+    verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, ctx)
+    verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4), target, ctx)
+    verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4), target, ctx)
 
 
 def verify_simple_dynamic_model(a_shape, b_shape, target, ctx):
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index 0df5a286b4e74..eafc743634d81 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -392,6 +392,9 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
     verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
     verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+    # Test backwards slicing.
+    verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
+    # Test slice mode.
     verify(
         (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
     )
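For reference, a minimal standalone sketch (not part of the patch) that checks the broadcast output shapes used in test_batch_matmul against numpy.matmul, which applies the same leading-batch-dimension broadcasting that the ONNX MatMul converter and BatchMatmulRel now handle. The shape triples mirror the (a_shape, b_shape, out_shape) test cases above.

import numpy as np

# Each triple is (a_shape, b_shape, expected_out_shape), copied from the tests above.
cases = [
    ((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4)),
    ((2, 4, 3), (3, 4), (2, 4, 4)),
    ((4, 3), (2, 3, 4), (2, 4, 4)),
    ((2, 4, 3), (1, 3, 4), (2, 4, 4)),
    ((1, 4, 3), (2, 3, 4), (2, 4, 4)),
]
for a_shape, b_shape, out_shape in cases:
    a = np.zeros(a_shape, dtype="float32")
    b = np.zeros(b_shape, dtype="float32")
    # np.matmul broadcasts the batch dimensions, so its result shape is the
    # reference the TVM output shape is expected to match.
    assert np.matmul(a, b).shape == out_shape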