From 18e1afa6056f34f927b32bcf7283fad2aca35fca Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 14 Oct 2020 01:50:58 +0000 Subject: [PATCH] formatting. --- python/tvm/relay/frontend/onnx.py | 5 ++++- python/tvm/relay/op/nn/_nn.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a224fc95a9f25..2c5b3bf7826d4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -518,6 +518,7 @@ def _impl_v1(cls, inputs, attr, params): b_rank = infer_shape(b_shape)[0] # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: + def flatten_to_3d(x, x_shape): ndims = infer_shape(x_shape)[0] newshape = _op.concatenate( @@ -542,7 +543,9 @@ def flatten_to_3d(x, x_shape): final_shape = _op.concatenate( [ out_batch, - _op.strided_slice(a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]), + _op.strided_slice( + a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1] + ), _op.strided_slice( b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]] ), diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 35cbaf7fb5204..e1aabe1e15b59 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -983,7 +983,7 @@ def _batch_matmul_shape_func(data_shape, weight_shape): out = output_tensor((data_shape.shape[0],), "int64") for i in const_range(out.shape[0] - 1): if i == 0: - out[i] = max(data_shape[i], weight_shape[i]) + out[i] = max(data_shape[i], weight_shape[i]) else: out[i] = data_shape[i] out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]