Skip to content

Commit

Permalink
formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwfromm committed Oct 14, 2020
1 parent 47701fa commit be8728f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ def _impl_v1(cls, inputs, attr, params):
b_rank = infer_shape(b_shape)[0]
# When performing a batch matmul, we need to properly handle N-dim shapes.
if a_rank > 2 or b_rank > 2:

def flatten_to_3d(x, x_shape):
ndims = infer_shape(x_shape)[0]
newshape = _op.concatenate(
Expand All @@ -542,7 +543,9 @@ def flatten_to_3d(x, x_shape):
final_shape = _op.concatenate(
[
out_batch,
_op.strided_slice(a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]),
_op.strided_slice(
a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
),
_op.strided_slice(
b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def _batch_matmul_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
if i == 0:
out[i] = max(data_shape[i], weight_shape[i])
out[i] = max(data_shape[i], weight_shape[i])
else:
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
Expand Down

0 comments on commit be8728f

Please sign in to comment.