Skip to content

Commit

Permalink
[microNPU] Remove identity operations between non-compute operations (#10411)

Browse files Browse the repository at this point in the history

Builds upon the work in #10254 to remove identity operations sandwiched
between two non-compute operations (reshape/strided slice - concatenate
is handled differently), under certain conditions. Specifically, an
identity operation is not removed when the second non-compute operation
reduces the dimensionality of the first one's output, since the values
could then be accessed incorrectly. For example,

```
strided_slice(dims=4) -> identity -> reshape(dims=4)
```
becomes...
```
strided_slice -> reshape
```
but,
```
strided_slice(dims=4) -> identity -> reshape(dims=2)
```
remains as...
```
strided_slice -> identity -> reshape
```

Change-Id: Ie28ba384fcb3230d6f4651c0c19e2b9526ebcc42
  • Loading branch information
lhutton1 authored Mar 25, 2022
1 parent 9de36f7 commit 2cb7695
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 12 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
mod = OutlineCompilerFunctions("ethos-u")(mod)
mod = LegalizeEthosU()(mod)
mod = LUTsOptimizer()(mod)
mod = relay.transform.InferType()(mod)
mod = IdentityOptimizer()(mod)
mod = LayoutOptimizer()(mod)
mod = relay.transform.InferType()(mod)
Expand Down
61 changes: 54 additions & 7 deletions src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,33 @@ class RemoveRedundantIdentities : public MixedModeMutator {
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Call call = Downcast<Call>(post);

// only consider rewrite if current op is an NPU compute op.
// don't consider rewrite if current op is an identity or concatenate.
if (!call->op->IsInstance<OpNode>()) {
return post;
}
const auto* op = call->op.as<OpNode>();
std::string op_name = op->name;
if (op_name.substr(0, 15) != "contrib.ethosu." || op_name == "contrib.ethosu.identity") {
if (op_name == "contrib.ethosu.identity" || op_name == "concatenate") {
return post;
}

// check if we can rewrite parent identity operations to current call.
bool needs_rewrite = false;
Array<Expr> new_args;
for (const auto& arg : call->args) {
if (const auto* parent_callnode = arg.as<CallNode>()) {
Expr current_arg = arg;

// expand tuple to get parent op if we run into one - nested tuples are not supported.
if (const auto* tuple_get_item = arg.as<TupleGetItemNode>()) {
const auto* tuple = tuple_get_item->tuple.as<TupleNode>();
current_arg = tuple->fields[tuple_get_item->index];
}

if (const auto* parent_callnode = current_arg.as<CallNode>()) {
if (const auto* parent_op = parent_callnode->op.as<OpNode>()) {
Call parent_call = GetRef<Call>(parent_callnode);
if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call)) {
if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call) &&
CheckIdentityBetweenTransformOperations(call, parent_call)) {
needs_rewrite = true;
new_args.push_back(parent_call->args[0]);
continue;
Expand All @@ -143,7 +152,10 @@ class RemoveRedundantIdentities : public MixedModeMutator {
}

if (needs_rewrite) {
return Call(call->op, new_args, call->attrs, call->type_args);
Call new_call = Call(call->op, new_args, call->attrs, call->type_args);
// since we are only removing an identity, we know the type information has not changed
new_call->checked_type_ = call->checked_type_;
return new_call;
}
return post;
}
Expand All @@ -156,6 +168,41 @@ class RemoveRedundantIdentities : public MixedModeMutator {
bool has_no_activation = attrs->activation == "NONE";
return does_not_requantize && has_no_activation;
}

bool CheckIdentityBetweenTransformOperations(const Call& call, const Call& identity_call) {
const auto* op = call->op.as<OpNode>();
std::vector<std::string> nc_ops = {"reshape", "strided_slice"};

if (op && (std::find(nc_ops.begin(), nc_ops.end(), op->name) != nc_ops.end())) {
// check if the parent to identity operation is also a non-compute operation,
// if it isn't we can safely remove the identity in question by returning true.
const auto* identity_arg = identity_call->args[0].as<CallNode>();
if (!identity_arg) {
return true;
}
const auto* identity_arg_op = identity_arg->op.as<OpNode>();
if (!identity_arg_op ||
!(std::find(nc_ops.begin(), nc_ops.end(), identity_arg_op->name) != nc_ops.end())) {
return true;
}

const auto* call_tt = call->checked_type_.as<TensorTypeNode>();
const auto* identity_arg_tt = identity_arg->checked_type_.as<TensorTypeNode>();
CHECK(call_tt && identity_arg_tt)
<< "InferType should be run before RemoveRedundantIdentities";

// we can only remove the identity operation if the second non-compute operation
// in the sequence does not reduce the dimensionality of the output to the first
// non-compute operation. Doing so could lead to data being accessed incorrectly
// by the subsequent compute operation due to the reduction in dimensionality.
size_t first_transform_op_dims = identity_arg_tt->shape.size();
size_t second_transform_op_dims = call_tt->shape.size();
if (second_transform_op_dims < first_transform_op_dims) {
return false;
}
}
return true;
}
};

/*!
Expand All @@ -177,8 +224,8 @@ tvm::transform::Pass IdentityOptimizer() {
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0,
"relay.backend.contrib.ethos-u.IdentityOptimizer", {});
return tvm::transform::CreateModulePass(
pass_func, 0, "relay.backend.contrib.ethos-u.IdentityOptimizer", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay.ext.ethos-u.IdentityOptimizer").set_body_typed(IdentityOptimizer);
Expand Down
47 changes: 42 additions & 5 deletions tests/python/contrib/test_ethosu/test_identity_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ def test_many_output_identity():
def get_graph(get_expected=False):
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
x = relay.reshape(x, newshape=(1, 1, 4, 4))
identity = infra.make_ethosu_identity(x)
if not get_expected:
x = infra.make_ethosu_identity(x)
outputs = []
for _ in range(4):
ifm = x if get_expected else identity
outputs.append(infra.make_ethosu_unary_elementwise(ifm, 4, "ABS"))
outputs.append(relay.strided_slice(identity, begin=(0, 0, 0, 0), end=(1, 1, 4, 4)))
outputs.append(infra.make_ethosu_unary_elementwise(x, 4, "ABS"))
ss = relay.strided_slice(x, begin=(0, 0, 0, 0), end=(1, 1, 4, 4))
identity_2 = infra.make_ethosu_identity(ss)
outputs.append(identity_2)
out = relay.concatenate(outputs, axis=0)
return relay.Function(relay.analysis.free_vars(out), out)

Expand Down Expand Up @@ -220,7 +222,8 @@ def test_identity_removal_with_multiple_transform_ops():
def get_graph(get_expected=False):
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
x = relay.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 2, 2, 2])
x = infra.make_ethosu_identity(x)
if not get_expected:
x = infra.make_ethosu_identity(x)
x = relay.reshape(x, newshape=(1, 1, 1, 8))
if not get_expected:
x = infra.make_ethosu_identity(x)
Expand Down Expand Up @@ -267,6 +270,25 @@ def get_graph(get_expected=False):
_assert_structural_equal(actual, expected)


def test_multiple_transform_ops_with_reduction_in_dimensionality():
    """Removal of an identity operation between two transform operations is usually okay.
    However, if the dimensionality of the input is reduced by the second transformation
    operation, it can lead to an output mismatch. Checking that the pass doesn't remove
    an identity given this case."""

    def build_graph():
        # strided_slice (4D) -> identity -> reshape (3D): the reshape reduces
        # dimensionality, so the identity in between must be preserved.
        inp = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
        sliced = relay.strided_slice(inp, begin=(0, 0, 0, 0), end=(1, 2, 2, 2))
        identity = infra.make_ethosu_identity(sliced)
        reshaped = relay.reshape(identity, newshape=(1, 2, 4))
        out = infra.make_ethosu_identity(reshaped)
        return relay.Function(relay.analysis.free_vars(out), out)

    # The optimized graph is expected to be structurally identical to the
    # unoptimized one, i.e. the pass makes no change.
    actual = _optimize(build_graph())
    expected = _optimize(build_graph(), optimize=False)
    _assert_structural_equal(actual, expected)


def test_identity_optimizer_runs_in_compilation_pipeline():
"""Checks that the identity optimization pass is run as part of the NPU compilation pipeline."""

Expand Down Expand Up @@ -320,3 +342,18 @@ def model(x):
return y

_compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")


def test_multiple_transform_ops_same_output():
    """Check case of identity removal between transform ops and
    then without, making sure they have the same output."""
    # Input feature map shape fed to the TFLite model.
    ifm_shape = (1, 2, 2, 4)

    @tf.function
    def model(x):
        # reshape (4D) -> slice -> reshape (1D): the final reshape reduces
        # dimensionality, the case where the identity between transform ops
        # must be kept for the outputs to match.
        x = tf.reshape(x, (1, 1, 4, 4))
        x = tf.slice(x, (0, 0, 0, 0), (1, 1, 4, 3))
        x = tf.reshape(x, (12,))
        return x

    # Compare the NPU-compiled output against the TFLite reference.
    _compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")

0 comments on commit 2cb7695

Please sign in to comment.