From 87128ed62cae7184d166c96f1bc10c65ae5901c1 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 2 Oct 2020 18:34:10 +0000
Subject: [PATCH 1/8] Allow batch_matmul to broadcast along batch dimension.

---
 include/tvm/topi/nn/batch_matmul.h          | 67 ------------------
 python/tvm/relay/frontend/onnx.py           | 10 +--
 python/tvm/topi/nn/batch_matmul.py          | 13 ++--
 python/tvm/topi/testing/batch_matmul.py     |  7 +-
 python/tvm/topi/x86/batch_matmul.py         | 10 +--
 src/relay/op/nn/nn.cc                       |  7 +-
 src/topi/nn.cc                              |  6 --
 tests/python/frontend/onnx/test_forward.py  |  1 +
 .../topi/python/test_topi_batch_matmul.py   | 21 +++---
 tests/python/topi/python/test_topi_dense.py |  2 +
 10 files changed, 41 insertions(+), 103 deletions(-)
 delete mode 100644 include/tvm/topi/nn/batch_matmul.h

diff --git a/include/tvm/topi/nn/batch_matmul.h b/include/tvm/topi/nn/batch_matmul.h
deleted file mode 100644
index bffddca8010f..000000000000
--- a/include/tvm/topi/nn/batch_matmul.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Batch matmul op constructions
- * \file nn/batch_matmul.h
- */
-#ifndef TVM_TOPI_NN_BATCH_MATMUL_H_
-#define TVM_TOPI_NN_BATCH_MATMUL_H_
-
-#include <tvm/te/operation.h>
-#include <tvm/topi/tags.h>
-
-#include <string>
-
-namespace tvm {
-namespace topi {
-namespace nn {
-
-using namespace tvm::te;
-
-/*!
- * \brief Creates an operation that calculates matrix multiplication in batch.
- *
- * \param x Tensor with shape [batch, M, K]
- * \param y Tensor with shape [batch, N, K]
- *
- * \return Tensor with shape [batch, M, N]
- */
-inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, const tvm::te::Tensor& y) {
-  CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
-  CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";
-
-  auto batch = x->shape[0];
-  auto M = x->shape[1];
-  auto K = x->shape[2];
-  auto N = y->shape[1];
-
-  auto k = tvm::te::reduce_axis(Range(0, K), "k");
-  auto result = tvm::te::compute(
-      {batch, M, N}, [&](Var b, Var i, Var j) { return tvm::sum(x(b, i, k) * y(b, j, k), {k}); },
-      "tensor", "batch_matmul");
-
-  return result;
-}
-
-}  // namespace nn
-}  // namespace topi
-}  // namespace tvm
-
-#endif  // TVM_TOPI_NN_BATCH_MATMUL_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 841ff77b142d..dcacb683c22a 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -499,18 +499,14 @@ def _impl_v1(cls, inputs, attr, params):
             # Convert a and b into 3 dimensional tensors.
             a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
             b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
-            # Broadcast b to match batch size of a
-            new_b_shape = list(infer_shape(b))
-            new_a_shape = infer_shape(a)
-            if new_a_shape[0] > new_b_shape[0]:
-                new_b_shape[0] = new_a_shape[0]
-                b = _op.broadcast_to(b, new_b_shape)
             # Transpose matrix dimensions of b.
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
+            # Determine output batch dim.
+            batch = a_shape[0] if (len(a_shape) != len(b_shape)) else max(a_shape[0], b_shape[0])
             # Reshape output to original dimensions.
-            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+            return _op.reshape(output, [batch, *a_shape[1:-2], a_shape[-2], b_shape[-1]])
         # Otherwise a simple dense op will get the job done.
         input_1_t = _op.transpose(inputs[1], axes=(1, 0))
         return _op.nn.dense(inputs[0], input_1_t)
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 7c8fead569ae..7d663ffd484a 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -22,7 +22,7 @@
 def batch_matmul(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
+    data in batch. Supports broadcasting for batch dimension.
 
     Parameters
     ----------
@@ -40,11 +40,16 @@ def batch_matmul(x, y):
     assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     x_shape = get_const_tuple(x.shape)
     y_shape = get_const_tuple(y.shape)
-    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+    XB = x_shape[0]
+    YB = y_shape[0]
+    assert (XB == YB) or (XB == 1) or (YB == 1), "batch dimension doesn't match"
     assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
-    batch, M, K = x.shape
+    _, M, K = x.shape
+    batch = max(XB, YB)
     N = y.shape[1]
     k = te.reduce_axis((0, K), name="k")
     return te.compute(
-        (batch, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
+        (batch, M, N),
+        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
+        tag="batch_matmul",
     )
diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py
index 0a991f633c3c..a48c92967c77 100644
--- a/python/tvm/topi/testing/batch_matmul.py
+++ b/python/tvm/topi/testing/batch_matmul.py
@@ -35,9 +35,10 @@ def batch_matmul(x, y):
     out : numpy.ndarray
         3-D with shape [batch, M, N]
     """
-    batch, M, _ = x.shape
-    N = y.shape[1]
+    XB, M, _ = x.shape
+    YB, N, _ = y.shape
+    batch = max(XB, YB)
     out = np.zeros((batch, M, N)).astype(x.dtype)
     for i in range(batch):
-        out[i] = np.dot(x[i], y[i].T)
+        out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
     return out
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index 333d3bed278b..668947e20c74 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -27,7 +27,7 @@
 @autotvm.register_topi_compute("batch_matmul.x86")
 def batch_matmul(cfg, x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
+    data in batch. Supports broadcasting in batch dimension.
 
     Parameters
     ----------
@@ -45,16 +45,18 @@ def batch_matmul(cfg, x, y):
     assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     XB, M, XK = get_const_tuple(x.shape)
     YB, N, YK = get_const_tuple(y.shape)
-    assert XB == YB, "batch dimension doesn't match"
+    assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
-    B = XB
+    B = max(XB, YB)
     K = XK
     if cfg.is_fallback:
         _default_batch_matmul_config(cfg, M, N, K)
     k = te.reduce_axis((0, K), name="k")
     C = te.compute(
-        (B, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
+        (B, M, N),
+        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
+        tag="batch_matmul",
     )
     return C
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 619b86d358d1..c257cc09e2ac 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -851,14 +851,15 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
   CHECK(x->shape.size() == 3 && y->shape.size() == 3);
-  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
-      << "BatchDot: batch dimension doesn't match, "
-      << " x shape=" << x->shape << ", y shape=" << y->shape;
+  //CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
+  //    << "BatchDot: batch dimension doesn't match, "
+  //    << " x shape=" << x->shape << ", y shape=" << y->shape;
   CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
       << "BatchDot: shapes of x and y is inconsistent, "
       << " x shape=" << x->shape << ", y shape=" << y->shape;
 
   Array<tvm::PrimExpr> oshape = x->shape;
+  oshape.Set(0, max(x->shape[0], y->shape[0]));
   oshape.Set(2, y->shape[1]);
 
   // assign output type
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index c03d1b056d35..2c9546507de6 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -24,7 +24,6 @@
 #include 
 #include 
 #include 
-#include <tvm/topi/nn/batch_matmul.h>
 #include 
 #include 
 #include 
@@ -68,11 +67,6 @@ TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* r
   *rv = nn::bias_add(args[0], args[1], args[2]);
 });
 
-/* Ops from nn/batch_matmul.h */
-TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = nn::batch_matmul(args[0], args[1]);
-});
-
 /* Ops from nn/dilate.h */
 TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = nn::dilate(args[0], args[1], args[2]);
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1c0fced6c3ef..f8e204c6cda2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -945,6 +945,7 @@ def test_batch_matmul():
     verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
     verify_batch_matmul((2, 4, 3), (3, 4))
     verify_batch_matmul((2, 3, 4, 3), (3, 4))
+    verify_batch_matmul((1, 4, 3), (2, 3, 4))
 
 
 def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
diff --git a/tests/python/topi/python/test_topi_batch_matmul.py b/tests/python/topi/python/test_topi_batch_matmul.py
index 769822e98488..0d82ee69fa26 100644
--- a/tests/python/topi/python/test_topi_batch_matmul.py
+++ b/tests/python/topi/python/test_topi_batch_matmul.py
@@ -32,16 +32,16 @@
 }
 
 
-def verify_batch_matmul(batch, M, N, K):
-    x = te.placeholder((batch, M, K), name="x")
-    y = te.placeholder((batch, N, K), name="y")
+def verify_batch_matmul(x_batch, y_batch, M, N, K):
+    x = te.placeholder((x_batch, M, K), name="x")
+    y = te.placeholder((y_batch, N, K), name="y")
     dtype = x.dtype
 
     # use memoize to pickle the test data for next time use
     @memoize("topi.tests.test_topi_batch_matmul")
     def get_ref_data():
-        a_np = np.random.uniform(size=(batch, M, K)).astype(dtype)
-        b_np = np.random.uniform(size=(batch, N, K)).astype(dtype)
+        a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+        b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
         c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
         return (a_np, b_np, c_np)
 
@@ -67,10 +67,13 @@ def check_device(device, ctx):
 
 @tvm.testing.uses_gpu
 def test_batch_matmul():
-    verify_batch_matmul(1, 16, 16, 32)
-    verify_batch_matmul(5, 16, 16, 32)
-    verify_batch_matmul(5, 16, 20, 32)
-    verify_batch_matmul(30, 16, 20, 32)
+    verify_batch_matmul(1, 1, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 16, 32)
+    verify_batch_matmul(5, 5, 16, 20, 32)
+    verify_batch_matmul(30, 30, 16, 20, 32)
+    # Test batch broadcasting.
+    verify_batch_matmul(1, 5, 16, 16, 32)
+    verify_batch_matmul(5, 1, 16, 16, 32)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py
index f46a271ced53..c9beca3d58ee 100644
--- a/tests/python/topi/python/test_topi_dense.py
+++ b/tests/python/topi/python/test_topi_dense.py
@@ -75,6 +75,8 @@ def check_device(device, ctx):
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
         d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+        print(tvm.lower(s, [A, B, C, D], simple_mode=True))
+        exit()
         f = tvm.build(s, [A, B, C, D], device, name="dense")
         f(a, b, c, d)
         tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)

From e0841ddb121025b963fbb1b4dc5939d6626d0534 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 2 Oct 2020 18:42:46 +0000
Subject: [PATCH 2/8] Added typerel checking.

---
 src/relay/op/nn/nn.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index c257cc09e2ac..74235572a854 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -851,9 +851,11 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
   CHECK(x->shape.size() == 3 && y->shape.size() == 3);
-  //CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
-  //    << "BatchDot: batch dimension doesn't match, "
-  //    << " x shape=" << x->shape << ", y shape=" << y->shape;
+  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) ||
+        reporter->AssertEQ(x->shape[0], 1) ||
+        reporter->AssertEQ(y->shape[0], 1))
+      << "BatchDot: batch dimension doesn't match, "
+      << " x shape=" << x->shape << ", y shape=" << y->shape;
   CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
       << "BatchDot: shapes of x and y is inconsistent, "
       << " x shape=" << x->shape << ", y shape=" << y->shape;

From 230765885ddff3a1c727d536aaa2c682c9eb3d37 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 2 Oct 2020 19:16:41 +0000
Subject: [PATCH 3/8] Fix style issue and respond to feedback.
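
The batch-dimension rule the series settles on, and whose error message this
patch rewords, can be stated as a minimal Python sketch (the helper name
output_batch is illustrative only and appears nowhere in the patches):

    def output_batch(xb, yb):
        # Batch dims must be equal, or one side must be 1 and is broadcast.
        assert xb == yb or xb == 1 or yb == 1, "batch dimensions don't match"
        return max(xb, yb)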
---
 python/tvm/topi/nn/batch_matmul.py          | 2 +-
 src/relay/op/nn/nn.cc                       | 5 +++--
 tests/python/topi/python/test_topi_dense.py | 2 --
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 7d663ffd484a..c0383091ae69 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -42,7 +42,7 @@ def batch_matmul(x, y):
     y_shape = get_const_tuple(y.shape)
     XB = x_shape[0]
     YB = y_shape[0]
-    assert (XB == YB) or (XB == 1) or (YB == 1), "batch dimension doesn't match"
+    assert (XB == YB) or (XB == 1) or (YB == 1), "batch dimensions don't match"
     assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
     _, M, K = x.shape
     batch = max(XB, YB)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 74235572a854..799531422ebe 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -35,6 +35,7 @@
 #include 
 #include 
+#include 
 
 #include "../../transforms/infer_layout_util.h"
 #include "../make_op.h"
@@ -851,10 +852,10 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
   CHECK(x->shape.size() == 3 && y->shape.size() == 3);
-  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || 
+  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) ||
         reporter->AssertEQ(x->shape[0], 1) ||
         reporter->AssertEQ(y->shape[0], 1))
-      << "BatchDot: batch dimension doesn't match, "
+      << "BatchDot: batch dimensions don't match, "
       << " x shape=" << x->shape << ", y shape=" << y->shape;
   CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
       << "BatchDot: shapes of x and y is inconsistent, "
       << " x shape=" << x->shape << ", y shape=" << y->shape;
diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py
index c9beca3d58ee..f46a271ced53 100644
--- a/tests/python/topi/python/test_topi_dense.py
+++ b/tests/python/topi/python/test_topi_dense.py
@@ -75,8 +75,6 @@ def check_device(device, ctx):
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
         d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
-        print(tvm.lower(s, [A, B, C, D], simple_mode=True))
-        exit()
         f = tvm.build(s, [A, B, C, D], device, name="dense")
         f(a, b, c, d)
         tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)

From a3cce46302a2626b610eb66e4ef02891644bd1ce Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 2 Oct 2020 19:19:38 +0000
Subject: [PATCH 4/8] Fix style.
---
 src/relay/op/nn/nn.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 799531422ebe..57d5f98978f9 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -852,8 +852,7 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
   CHECK(x->shape.size() == 3 && y->shape.size() == 3);
-  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) ||
-        reporter->AssertEQ(x->shape[0], 1) ||
+  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
         reporter->AssertEQ(y->shape[0], 1))
       << "BatchDot: batch dimensions don't match, "
       << " x shape=" << x->shape << ", y shape=" << y->shape;

From a522773f6c9b7796662310864b910aaca8c53040 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 2 Oct 2020 19:25:26 +0000
Subject: [PATCH 5/8] More formatting issues :(

---
 src/relay/op/nn/nn.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 57d5f98978f9..bc6c7a68e67e 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -33,9 +33,9 @@
 #include 
 #include 
+#include 
 #include 
 #include 
-#include 

From f4cc2969e69c396f1e8aa36bcff8813b0cc843dd Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 5 Oct 2020 19:24:46 +0000
Subject: [PATCH 6/8] Fix issues after merge.

---
 python/tvm/relay/frontend/onnx.py          | 3 +--
 python/tvm/topi/nn/batch_matmul.py         | 6 ++++--
 tests/python/frontend/onnx/test_forward.py | 3 ---
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index db950e0b593d..ce9267de52e2 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -546,8 +546,7 @@ def flatten_to_3d(x, x_shape):
         # Compute output shape.
         final_shape = _op.concatenate(
             [
-                _op.maximum(_op.strided_slice(a_shape, [0], [1]), _op.strided_slice(b_shape, [0], [1])),
-                _op.strided_slice(a_shape, [1], [infer_shape(a_shape)[0] - 1]),
+                _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]),
                 _op.strided_slice(
                     b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
                 ),
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 52d6d550e222..9b926a1182d8 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -31,7 +31,7 @@ def batch_matmul(x, y, oshape=None):
 
     y : tvm.te.Tensor
         3-D with shape [batch, N, K]
-
+
     oshape : List[Optional]
         Explicit intended output shape of the computation. Can be useful in cases
        with dynamic input shapes.
@@ -55,5 +55,7 @@ def batch_matmul(x, y, oshape=None):
     N = y.shape[1]
     oshape = (batch, M, N)
     return te.compute(
-        oshape, lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), tag="batch_matmul"
+        oshape,
+        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
+        tag="batch_matmul",
     )
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index d6754bf9a976..1612a4ab1a27 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -985,7 +985,6 @@ def test_batch_matmul(target, ctx):
     verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
     verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
     verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)
-    verify_batch_matmul((1, 4, 3), (2, 3, 4), target, ctx)
 
 
 def verify_simple_dynamic_model(a_shape, b_shape, target, ctx):
@@ -1037,7 +1036,6 @@ def test_batch_matmul_dynamic_model(target, ctx):
     verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
     verify_simple_dynamic_model((2, 4, 3), (3, 4), target, ctx)
     verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, ctx)
-    verify_simple_dynamic_model((1, 4, 3), (2, 3, 4), target, ctx)


def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
@@ -3630,7 +3628,6 @@ def verify_roi_align(
     test_clip_min_max_as_inputs()
     test_onehot()
     test_matmul()
-    test_batch_matmul()
     test_gather()
     test_gatherelements()
     test_gather_nd()

From 5c947966fa3817a7161a05ee57f0550afc30940a Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 5 Oct 2020 19:26:17 +0000
Subject: [PATCH 7/8] Comment update.

---
 python/tvm/relay/frontend/onnx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ce9267de52e2..f4cf9572b93f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -543,7 +543,7 @@ def flatten_to_3d(x, x_shape):
         b = _op.transpose(b, [0, 2, 1])
         # Perform a batch matmul.
         output = _op.nn.batch_matmul(a, b)
-        # Compute output shape.
+        # Reshape output to original dimensions.
         final_shape = _op.concatenate(
             [
                 _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 1]),

From d6d7afec24e39e6aefd219bec415bbf5e65c07b4 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 5 Oct 2020 19:27:41 +0000
Subject: [PATCH 8/8] Small tweak.

---
 tests/python/frontend/onnx/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1612a4ab1a27..da8629dfcd2b 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1036,7 +1036,7 @@ def test_batch_matmul_dynamic_model(target, ctx):
     verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
     verify_simple_dynamic_model((2, 4, 3), (3, 4), target, ctx)
     verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, ctx)
-    
+


def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
    in_array = np.random.uniform(size=shape).astype(dtype)
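
For reference, the broadcasting semantics introduced by this series can be
reproduced with the standalone NumPy sketch below, which mirrors the updated
python/tvm/topi/testing/batch_matmul.py. The function name batch_matmul_ref and
the example shapes are illustrative and not part of the patches.

    import numpy as np

    def batch_matmul_ref(x, y):
        # x: [XB, M, K], y: [YB, N, K]; batches must match, or either may be 1.
        XB, M, _ = x.shape
        YB, N, _ = y.shape
        assert XB == YB or XB == 1 or YB == 1, "batch dimensions don't match"
        batch = max(XB, YB)
        out = np.zeros((batch, M, N)).astype(x.dtype)
        for i in range(batch):
            # A batch of size 1 is read at index 0 on every iteration, i.e. broadcast.
            out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
        return out

    # Example: broadcast a single x matrix against five batches of y. As in topi,
    # y is laid out [batch, N, K], so each slice is transposed before the product.
    a = np.random.uniform(size=(1, 16, 32)).astype("float32")
    b = np.random.uniform(size=(5, 20, 32)).astype("float32")
    assert batch_matmul_ref(a, b).shape == (5, 16, 20)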