Skip to content

Commit

Permalink
[ETHOSN] Support multiply conversion to depthwise (#12403)
Browse files Browse the repository at this point in the history
Multiply can be supported when offloaded to the NPU by a conversion to a depthwise convolution operation. This is only supported when the multiply operation has a single single variable input with the other being a constant of shape [1, ..., C]. This commit adds a new pass "ConvertEquivalents" (name subject to change) to handle this conversion before codegen.
  • Loading branch information
lhutton1 authored Aug 24, 2022
1 parent 6e79f64 commit a0fe74b
Show file tree
Hide file tree
Showing 8 changed files with 582 additions and 18 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/contrib/_ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
import tvm._ffi

tvm._ffi._init_api("relay.ethos-n.support", __name__)
tvm._ffi._init_api("relay.backend.contrib.ethos-n", __name__)
80 changes: 62 additions & 18 deletions python/tvm/relay/op/contrib/ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import is_constant, is_op, wildcard
from . import _ethosn as support
from . import _ethosn
from .register import register_pattern_table


Expand Down Expand Up @@ -60,6 +60,18 @@ def ethosn_api_version() -> str:
return tvm.get_global_func("relay.ethos-n.api.version")()


def ConvertEquivalents() -> tvm.ir.IRModule: # pylint: disable=invalid-name
"""Converts operations into a numerically equivalent form
that can be understood by the NPU codegen.
Return
------
Pass
The module pass.
"""
return _ethosn.ConvertEquivalents()


def partition_for_ethosn(mod, params=None, **opts):
"""Partition the graph greedily offloading supported
operators to Arm Ethos-N NPU.
Expand Down Expand Up @@ -107,9 +119,9 @@ def partition_for_ethosn(mod, params=None, **opts):
transform.AnnotateTarget("ethos-n"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
ConvertEquivalents(),
]
)

return seq(mod)


Expand Down Expand Up @@ -183,70 +195,102 @@ def qnn_resize_pattern():
)
return pattern

def qnn_mul_pattern():
"""
Multiply is supported when one input is a constant of shape [1, ..., C],
where C matches the number of channels of the other input.
"""
mul_op = is_op("qnn.mul")
gen_mul_inputs = lambda x, y: mul_op(
x,
y,
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)
input_is_left = gen_mul_inputs(wildcard(), is_constant())
input_is_right = gen_mul_inputs(is_constant(), wildcard())
return input_is_left | input_is_right

def check_conv2d(extract):
"""Check if a conv2d is supported by Ethos-N."""
if not ethosn_available():
return False

return support.conv2d(extract)
return _ethosn.conv2d(extract)

def check_fc(extract):
"""Check if a fully connected is supported by Ethos-N."""
if not ethosn_available():
return False

return support.fc(extract)
return _ethosn.fc(extract)

def check_avg_pool2d(extract):
"""Check if a avg pool2d is supported by Ethos-N."""
if not ethosn_available():
return False

return support.avg_pool2d(extract)
return _ethosn.avg_pool2d(extract)

def check_mean(extract):
"""Check if mean is supported by Ethos-N."""
if not ethosn_available():
return False

return support.mean(extract)
return _ethosn.mean(extract)

def check_sigmoid(extract):
"""Check if a sigmoid is supported by Ethos-N."""
if not ethosn_available():
return False

return support.sigmoid(extract)
return _ethosn.sigmoid(extract)

def check_tanh(extract):
"""Check if tanh is supported by Ethos-N."""
if not ethosn_available():
return False

return support.tanh(extract)
return _ethosn.tanh(extract)

def check_leaky_relu(extract):
"""Check if Leaky ReLU is supported."""
if not ethosn_available():
return False

return support.leaky_relu(extract)
return _ethosn.leaky_relu(extract)

def check_mul(extract):
"""Check if Mul is supported."""
if not ethosn_available():
return False
# Do not support scalar constants for now
check_scalar = lambda i: isinstance(i, tvm.relay.Constant) and len(i.data.shape) == 0
if check_scalar(extract.args[0]) or check_scalar(extract.args[1]):
return False
extract = _ethosn.ConvertQnnMultiply(extract)
return _ethosn.conv2d(extract)

def check_requantize(extract):
"""Check if requantize is supported."""
if not ethosn_available():
return False

return support.requantize(extract)
return _ethosn.requantize(extract)

def check_resize(extract):
"""Check if resize (nearest neighbor) is supported."""
if not ethosn_available():
return False

return support.resize(extract)
return _ethosn.resize(extract)

return [
("ethos-n.qnn_mul", qnn_mul_pattern(), check_mul),
("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
Expand Down Expand Up @@ -274,7 +318,7 @@ def max_pool2d(expr):
if not ethosn_available():
return False

return support.max_pool2d(expr)
return _ethosn.max_pool2d(expr)


@tvm.ir.register_op_attr("reshape", "target.ethos-n")
Expand All @@ -285,7 +329,7 @@ def reshape(expr):
if not _is_ethosn_composite(expr.args[0]):
return False

return support.reshape(expr)
return _ethosn.reshape(expr)


@tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
Expand All @@ -294,15 +338,15 @@ def qnn_add(expr):
if not ethosn_available():
return False

return support.addition(expr)
return _ethosn.addition(expr)


@tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
def qnn_concatenate(expr):
"""Check if a concatenate is supported by Ethos-N."""
if not ethosn_available():
return False
if not support.concatenate(expr):
if not _ethosn.concatenate(expr):
return False

# Support library has some unenforced restrictions on qnn params
Expand Down Expand Up @@ -332,7 +376,7 @@ def split(expr):
return False
if ethosn_api_version() >= LooseVersion("3.0.1"):
return False
if not support.split(expr):
if not _ethosn.split(expr):
return False

return True
Expand All @@ -343,7 +387,7 @@ def depth_to_space(expr):
"""Check if a depth_to_space is supported by Ethos-N."""
if not ethosn_available():
return False
if not support.depth_to_space(expr):
if not _ethosn.depth_to_space(expr):
return False

return True
Expand All @@ -354,7 +398,7 @@ def clip(expr):
"""Check if a clip is supported by Ethos-N."""
if not ethosn_available():
return False
if not support.relu(expr):
if not _ethosn.relu(expr):
return False

return True
144 changes: 144 additions & 0 deletions src/relay/backend/contrib/ethosn/convert_equivalent.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/backend/contrib/ethosn/convert_equivalent.cc
* \brief Converts operations into a numerically equivalent form
* that can be understood by the NPU codegen.
*/

#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>

#include <unordered_map>

#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"
#include "../../../transforms/simplify_expr.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace ethosn {

/*!
* \brief Converts qnn.mul to mathematically equivalent
* qnn.conv2d depthwise operation.
*/
Expr ConvertQnnMultiply(const Expr& expr) {
Call call = Downcast<Call>(expr);

Expr input1 = call->args[0];
Expr input2 = call->args[1];
Expr input1_scale = call->args[2];
Expr input1_zero_point = call->args[3];
Expr input2_scale = call->args[4];
Expr input2_zero_point = call->args[5];
// Reverse the inputs if the constant is first input
if (call->args[0]->IsInstance<ConstantNode>()) {
input1 = call->args[1];
input2 = call->args[0];
input1_scale = call->args[4];
input1_zero_point = call->args[5];
input2_scale = call->args[2];
input2_zero_point = call->args[3];
}
Expr output_scale = call->args[6];
Expr output_zero_point = call->args[7];

const auto* input_constant = input2.as<ConstantNode>();
ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey();
const auto* input_constant_tt = input_constant->checked_type().as<TensorTypeNode>();
int channels = input_constant_tt->shape.back().as<IntImmNode>()->value;

runtime::NDArray input_data = input_constant->data;
runtime::NDArray kernel_data_hwoi =
runtime::NDArray::Empty({1, 1, channels, 1}, input_data->dtype, input_data->device);
kernel_data_hwoi.CopyFrom(input_data);
Constant kernel = Constant(kernel_data_hwoi, input_constant->span);

Type output_type = expr->checked_type();
auto output_tt = output_type.as<TensorTypeNode>();
ICHECK(output_tt) << "Expected TensorTypeNode but got " << output_type->GetTypeKey();
DataType output_dtype = output_tt->dtype;

Expr conv2d = qnn::MakeQnnConv2D(
input1, kernel, input1_zero_point, input2_zero_point, input1_scale, input2_scale, {1, 1},
{0, 0, 0, 0}, {1, 1}, channels, channels, {1, 1}, "NHWC", "HWOI", "NHWC", DataType::Int(32));
Constant bias_data = MakeConstantZeros(DataType::Int(32), {channels});
Expr bias_add = MakeBiasAdd(conv2d, bias_data, 3);
Expr requantize = qnn::MakeRequantize(bias_add, input1_scale, input1_zero_point, output_scale,
output_zero_point, -1, "None", "None", output_dtype);

return InferType(requantize);
}

TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertQnnMultiply")
.set_body_typed(ConvertQnnMultiply);

class ConvertEquivalentsMutator : public MixedModeMutator {
public:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Call call = Downcast<Call>(post);
if (!call->op->IsInstance<FunctionNode>()) {
return post;
}

Function func = Downcast<Function>(call->op);
Function new_func = Function(func);
auto composite_name = func->GetAttr<String>(attr::kComposite);
if (composite_name == "ethos-n.qnn_mul") {
Expr new_func_body = ConvertQnnMultiply(func->body);
new_func = WithFields(func, func->params, new_func_body);
new_func = WithAttr(std::move(new_func), attr::kComposite, String("ethos-n.qnn_conv2d"));
}

Call new_call = WithFields(call, new_func);
return Downcast<Expr>(new_call);
}
};

tvm::transform::Pass ConvertEquivalents() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule mod, transform::PassContext ctx) {
for (auto gv : mod->GetGlobalVars()) {
Function func = Downcast<Function>(mod->Lookup(gv));
auto compiler_name = func->GetAttr<String>(attr::kCompiler);
if (compiler_name.defined() && compiler_name == "ethos-n") {
auto new_body = ConvertEquivalentsMutator().VisitExpr(func->body);
if (!new_body.same_as(func->body)) {
Function new_func = WithFields(func, func->params, new_body);
mod->Update(gv, new_func);
}
}
}
return mod;
};
return tvm::transform::CreateModulePass(
pass_func, 0, "relay.backend.contrib.ethos-n.ConvertEquivalents", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertEquivalents")
.set_body_typed(ConvertEquivalents);

} // namespace ethosn
} // namespace contrib
} // namespace relay
} // namespace tvm
2 changes: 2 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ Expr MakeShapeOf(Expr data, DataType dtype);

Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);

Expr MakeBiasAdd(Expr data, Expr bias, int axis);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
4 changes: 4 additions & 0 deletions src/relay/qnn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
attrs.operator->(), input_shape, attrs->out_dtype);
}

Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
Expr output_zero_point, int axis, String rounding, String compute_dtype,
DataType out_dtype);

Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
const DequantizeAttrs* attrs);
Expand Down
Loading

0 comments on commit a0fe74b

Please sign in to comment.