From 9d6741c3dd29c4dde861aa1d3b2ca85f560f5ac6 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 8 Apr 2022 09:35:51 +0900
Subject: [PATCH] [QNN] Fix broadcast for invalid axis

---
 src/relay/qnn/op/op_common.h          | 13 +++++++++----
 tests/python/relay/test_op_qnn_add.py | 24 ++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h
index 6d1eb3a343867..da01b88711f37 100644
--- a/src/relay/qnn/op/op_common.h
+++ b/src/relay/qnn/op/op_common.h
@@ -255,14 +255,19 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
   }
 
   const BroadcastAttrs* broadcast_attrs = attrs.as<BroadcastAttrs>();
-  int lhs_axis = broadcast_attrs->lhs_axis;
-  int rhs_axis = broadcast_attrs->rhs_axis;
+  ICHECK(broadcast_attrs);
 
   auto lhs_rank = static_cast<int>(lhs_data->shape.size());
   auto rhs_rank = static_cast<int>(rhs_data->shape.size());
 
-  lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) : lhs_axis;
-  rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_rank + rhs_axis : 0) : rhs_axis;
+  auto get_broadcast_axis = [](int rank, int axis_from_attr) {
+    if (rank <= 1) return 0;
+    if (axis_from_attr < 0) return rank + axis_from_attr;
+    return axis_from_attr;
+  };
+
+  const int lhs_axis = get_broadcast_axis(lhs_rank, broadcast_attrs->lhs_axis);
+  const int rhs_axis = get_broadcast_axis(rhs_rank, broadcast_attrs->rhs_axis);
 
   // If zero point and scale are scalar then axis doesn't matter.
   bool lhs_scale_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
diff --git a/tests/python/relay/test_op_qnn_add.py b/tests/python/relay/test_op_qnn_add.py
index b38ada718cc59..44f09a321aa34 100644
--- a/tests/python/relay/test_op_qnn_add.py
+++ b/tests/python/relay/test_op_qnn_add.py
@@ -232,7 +232,31 @@ def test_saturation():
     np.testing.assert_equal(op_res.numpy(), golden_output)
 
 
+def test_ignore_broadcast_axis():
+    data_dtype = "uint8"
+
+    x = relay.var("x", shape=(4,), dtype=data_dtype)
+    y = relay.var("y", shape=(4,), dtype=data_dtype)
+    z = relay.qnn.op.add(
+        lhs=x,
+        rhs=y,
+        lhs_scale=relay.const(0.00784314, "float32"),
+        lhs_zero_point=relay.const(127, "int32"),
+        rhs_scale=relay.const(0.00784314, "float32"),
+        rhs_zero_point=relay.const(127, "int32"),
+        output_scale=relay.const(0.00784314, "float32"),
+        output_zero_point=relay.const(127, "int32"),
+        lhs_axis=1,
+        rhs_axis=1,
+    )
+
+    func = relay.Function([x, y], z)
+    mod = tvm.IRModule.from_expr(func)
+    mod = relay.transform.InferType()(mod)
+
+
 if __name__ == "__main__":
     test_tflite_same_io_qnn_params()
     test_tflite_different_io_qnn_params()
     test_saturation()
+    test_ignore_broadcast_axis()