Skip to content

Commit

Permalink
support multiple output identities and add more tests
Browse files Browse the repository at this point in the history
Change-Id: Ib54031fe1c70159728876a23f96b72adb2ea17b0
  • Loading branch information
lhutton1 committed Feb 17, 2022
1 parent 8e43a6b commit 7d2b6c8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 53 deletions.
53 changes: 5 additions & 48 deletions src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,46 +102,12 @@ tvm::transform::Pass RelayToTIR() {
}

/*!
* \brief This visitor counts the number of outputs each identity operation has, since Relay doesn't
* keep references to child nodes.
*/
class CountIdentityOutputs : public MixedModeVisitor {
public:
void VisitExpr_(const CallNode* call) {
for (auto arg : call->args) {
if (const auto* parent_callnode = arg.as<CallNode>()) {
if (const auto* parent_op = parent_callnode->op.as<OpNode>()) {
if (parent_op->name != "contrib.ethosu.identity") {
continue;
}

Call parent_call = GetRef<Call>(parent_callnode);
Optional<Integer> current_count = output_count_.Get(parent_call);
if (current_count) {
output_count_.Set(parent_call, Integer(current_count.as<IntImmNode>()->value + 1));
} else {
output_count_.Set(parent_call, 1);
}
}
}
}
}

Map<Call, Integer> GetOutputCountMap() { return output_count_; }

private:
Map<Call, Integer> output_count_;
};

/*!
* \brief This mutator removes identity operations that are not necessary. Specifically, an identity
* operation can be removed when it is immediately followed by an NPU compute operation.
* \brief This mutator removes identity operations that are not necessary. Specifically, an
* identity operation can be removed when it is immediately followed by an NPU compute
* operation.
*/
class RemoveRedundantIdentities : public MixedModeMutator {
public:
explicit RemoveRedundantIdentities(Map<Call, Integer> identity_output_count)
: identity_output_count_(identity_output_count) {}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Call call = Downcast<Call>(post);

Expand All @@ -162,11 +128,7 @@ class RemoveRedundantIdentities : public MixedModeMutator {
if (const auto* parent_callnode = arg.as<CallNode>()) {
if (const auto* parent_op = parent_callnode->op.as<OpNode>()) {
Call parent_call = GetRef<Call>(parent_callnode);
// TODO(lhutton1) support removal of identities with multiple outputs.
bool has_single_output = identity_output_count_.Get(parent_call) == 1;

if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call) &&
has_single_output) {
if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call)) {
needs_rewrite = true;
new_args.push_back(parent_call->args[0]);
continue;
Expand All @@ -190,8 +152,6 @@ class RemoveRedundantIdentities : public MixedModeMutator {
bool has_no_activation = attrs->activation == "NONE";
return does_not_requantize && has_no_activation;
}

Map<Call, Integer> identity_output_count_;
};

/*!
Expand All @@ -202,10 +162,7 @@ tvm::transform::Pass IdentityOptimizer() {
[=](IRModule mod, transform::PassContext ctx) {
for (auto gv : mod->GetGlobalVars()) {
Function main_func = Downcast<Function>(mod->Lookup(gv));
CountIdentityOutputs counter = CountIdentityOutputs();
counter.VisitExpr(main_func->body);
auto new_main_body =
RemoveRedundantIdentities(counter.GetOutputCountMap()).VisitExpr(main_func->body);
auto new_main_body = RemoveRedundantIdentities().VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
Function new_main_func = WithFields(main_func, main_func->params, new_main_body);
mod->Update(gv, new_main_func);
Expand Down
64 changes: 59 additions & 5 deletions tests/python/contrib/test_ethosu/test_identity_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,56 @@ def get_graph():
_assert_structural_equal(actual, expected)


def test_multiple_output_identity():
"""Check that an identity isn't removed when it has multiple outputs,
as this is not supported yet."""
def test_activation_identity_no_removal():
"""Check thst an identity with an activation isn't removed."""

def get_graph():
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
x = infra.make_ethosu_identity(x)
x = relay.reshape(x, newshape=(1, 1, 4, 4))
x = infra.make_ethosu_identity(x, activation="LUT")
x = infra.make_ethosu_unary_elementwise(x, 4, "ABS")
return relay.Function(relay.analysis.free_vars(x), x)

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)


def test_multiple_output_identity():
"""Check that an identity is removed when it has multiple outputs."""

def get_graph(get_expected=False):
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
if not get_expected:
x = infra.make_ethosu_identity(x)
y = infra.make_ethosu_unary_elementwise(x, 4, "ABS")
z = infra.make_ethosu_unary_elementwise(x, 4, "ABS")
out = relay.concatenate((y, z), axis=0)
return relay.Function(relay.analysis.free_vars(x), out)

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)


def test_many_output_identity():
"""Check an identity with many outputs. It cannot be removed due
to having a strided slice as output."""

def get_graph(get_expected=False):
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
x = relay.reshape(x, newshape=(1, 1, 4, 4))
identity = infra.make_ethosu_identity(x)
outputs = []
for _ in range(4):
ifm = x if get_expected else identity
outputs.append(infra.make_ethosu_unary_elementwise(ifm, 4, "ABS"))
outputs.append(relay.strided_slice(identity, begin=(0, 0, 0, 0), end=(1, 1, 4, 4)))
out = relay.concatenate(outputs, axis=0)
return relay.Function(relay.analysis.free_vars(out), out)

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)


Expand Down Expand Up @@ -265,3 +301,21 @@ def model(x, y):
return z

_compare_tvm_with_tflite(model, ifm_shapes, "ethos-u55-256")


def test_multi_output_identity_has_same_output():
"""Check that the output remains the same with an identity with
multiple outputs."""
ifm_shape = (1, 1, 64, 16)

@tf.function
def model(x):
x = tf.reshape(x, (1, 8, 8, 16))
outputs = []
for _ in range(4):
outputs.append(tf.nn.max_pool2d(x, 1, 1, "VALID"))
outputs.append(tf.reshape(x, (1, 8, 8, 16)))
y = tf.concat(outputs, axis=0)
return y

_compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")

0 comments on commit 7d2b6c8

Please sign in to comment.