Skip to content

Commit

Permalink
[QNN] Fix broadcast for invalid axis
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 8, 2022
1 parent 6ccde09 commit 9d6741c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/relay/qnn/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,19 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
}

const BroadcastAttrs* broadcast_attrs = attrs.as<BroadcastAttrs>();
int lhs_axis = broadcast_attrs->lhs_axis;
int rhs_axis = broadcast_attrs->rhs_axis;
ICHECK(broadcast_attrs);

auto lhs_rank = static_cast<int>(lhs_data->shape.size());
auto rhs_rank = static_cast<int>(rhs_data->shape.size());

lhs_axis = (lhs_axis < 0) ? ((lhs_rank > 0) ? lhs_rank + lhs_axis : 0) : lhs_axis;
rhs_axis = (rhs_axis < 0) ? ((rhs_rank > 0) ? rhs_rank + rhs_axis : 0) : rhs_axis;
auto get_broadcast_axis = [](int rank, int axis_from_attr) {
if (rank <= 1) return 0;
if (axis_from_attr < 0) return rank + axis_from_attr;
return axis_from_attr;
};

const int lhs_axis = get_broadcast_axis(lhs_rank, broadcast_attrs->lhs_axis);
const int rhs_axis = get_broadcast_axis(rhs_rank, broadcast_attrs->rhs_axis);

// If zero point and scale are scalar then axis doesn't matter.
bool lhs_scale_is_scalar = (types[2].as<TensorTypeNode>())->shape.size() == 0;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/relay/test_op_qnn_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,31 @@ def test_saturation():
np.testing.assert_equal(op_res.numpy(), golden_output)


def test_ignore_broadcast_axis():
    # Regression test for qnn.add with out-of-range lhs_axis/rhs_axis on 1-D
    # inputs: all quantization params below are scalar, so the broadcast axes
    # must simply be ignored instead of tripping type inference.
    data_dtype = "uint8"

    lhs = relay.var("x", shape=(4,), dtype=data_dtype)
    rhs = relay.var("y", shape=(4,), dtype=data_dtype)

    # Scalar scales / zero points — axis is irrelevant for these.
    qnn_params = {
        "lhs_scale": relay.const(0.00784314, "float32"),
        "lhs_zero_point": relay.const(127, "int32"),
        "rhs_scale": relay.const(0.00784314, "float32"),
        "rhs_zero_point": relay.const(127, "int32"),
        "output_scale": relay.const(0.00784314, "float32"),
        "output_zero_point": relay.const(127, "int32"),
    }
    out = relay.qnn.op.add(lhs=lhs, rhs=rhs, lhs_axis=1, rhs_axis=1, **qnn_params)

    func = relay.Function([lhs, rhs], out)
    mod = tvm.IRModule.from_expr(func)
    # Completing without an exception is the assertion here.
    relay.transform.InferType()(mod)


if __name__ == "__main__":
    # Run every test case in this file when invoked directly as a script.
    for test_case in (
        test_tflite_same_io_qnn_params,
        test_tflite_different_io_qnn_params,
        test_saturation,
        test_ignore_broadcast_axis,
    ):
        test_case()

0 comments on commit 9d6741c

Please sign in to comment.