Skip to content

Commit

Permalink
Reused conv2d pattern matching checks
Browse files Browse the repository at this point in the history
Change-Id: I1e032295668e68849a365da0dde28234c66cc056
  • Loading branch information
ashutosh-arm committed Aug 22, 2022
1 parent 29482e5 commit b0b48b3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 29 deletions.
29 changes: 1 addition & 28 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,12 @@ def check_qnn_conv2d_pad(pattern):
else:
requantize = pattern
requantize_input = requantize.args[0]
bias_add = None
bias_dtype = "int32"
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
conv2d = bias_add.args[0]
bias_dtype = bias_add.args[1].checked_type.dtype
else:
conv2d = requantize_input
conv2d_input = conv2d.args[0]
conv2d_weight = conv2d.args[1]

# check if sum of paddings from pad() and conv2d() satisfies CMSIS-NN constraints
can_pad_be_fused = True
Expand All @@ -186,30 +182,7 @@ def check_qnn_conv2d_pad(pattern):
pad_h_diff = int(pad_bottom - pad_top)
can_pad_be_fused = pad_w_diff in [0, 1] and pad_h_diff in [0, 1]

# kernel zero_point should be 0
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

# check if depthwise Conv2D
kernel_layout = conv2d.attrs.kernel_layout
pos_o = kernel_layout.index("O")
groups = conv2d.attrs.groups
is_depthwise = False
if groups == int(conv2d_input.checked_type.shape[3]) and groups == int(
conv2d_weight.checked_type.shape[pos_o]
):
is_depthwise = True

ret = (
conv2d.attrs.out_dtype == "int32"
and conv2d_input.checked_type.dtype == "int8"
and conv2d_weight.checked_type.dtype == "int8"
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
and all([zp == 0 for zp in kernel_zp])
and (not is_depthwise or bias_add is not None)
and can_pad_be_fused
)
ret = check_qnn_conv2d(pattern) and can_pad_be_fused
return ret

def qnn_fully_connected_pattern():
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/contrib/cmsisnn/fuse_pads.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class FusePadsMutator : public MixedModeMutator {
explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {}

private:
/*! * \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto
/*!
* \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto
* convolution layer to update Conv2DAttrs's padding attribute. */
void UpdateConv2DPadding(const CallNode* conv2d_call, const CallNode* pad_call,
Attrs* new_attrs) {
Expand Down

0 comments on commit b0b48b3

Please sign in to comment.