Skip to content

Commit

Permalink
[Relay][Pass] Add ExtractOperators pass (apache#8996)
Browse files Browse the repository at this point in the history
* add extractor

* extract to array

* add comments

* lint

* Update tests/python/relay/test_analysis_extract_operators.py

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

* op freqs

* add comment

* Update python/tvm/relay/analysis/analysis.py

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

* oops

* mixedmode visitor

* oops

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 95860bb commit 45b0299
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,23 @@ def extract_fused_functions(mod):
return ret


def list_op_freqs(mod):
"""Pass to extract unique operator names and how frequently they appear
in an IRModule. Fused functions are traversed to count the operators
that compose them.
Parameters
----------
mod : tvm.IRModule
Returns
-------
ret : Dict[str, int]
Dict of unique operator names to frequency
"""
return _ffi_api.ExtractOperators(mod)


def search_fc_transpose(expr):
"""Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
Expand Down
75 changes: 75 additions & 0 deletions src/relay/analysis/extract_operators.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file extract_operators.cc
* \brief Extract unique operators from an IRModule
*/
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

class OperatorExtractorWrapper : private MixedModeVisitor {
public:
explicit OperatorExtractorWrapper(const IRModule& mod) : mod_(mod) {}

Map<String, tvm::Integer> Extract() {
VisitExpr(this->mod_->Lookup("main"));

return operator_freqs_;
}

private:
const IRModule mod_;
/*! \brief Map of operator to frequency. */
Map<String, tvm::Integer> operator_freqs_;

void VisitExpr_(const CallNode* n) final {
VisitExpr(n->op);

auto op = n->op.as<OpNode>();
if (op) {
auto it = operator_freqs_.find(op->name);
ICHECK(it != operator_freqs_.end())
<< "Call's OpNode must be visited and registered before access";
operator_freqs_.Set(op->name, 1 + operator_freqs_.at(op->name));
}

MixedModeVisitor::VisitExpr_(n);
}

void VisitExpr_(const OpNode* n) final {
// NOTE: OpNode is visited only once for every operator kind
// regardless of how many times that op appears in the graph.
operator_freqs_.Set(n->name, 0U);
}
};

Map<String, tvm::Integer> ExtractOperatorsPacked(const IRModule& mod) {
return OperatorExtractorWrapper(mod).Extract();
}

TVM_REGISTER_GLOBAL("relay.analysis.ExtractOperators").set_body_typed(ExtractOperatorsPacked);

} // namespace relay
} // namespace tvm
107 changes: 107 additions & 0 deletions tests/python/relay/test_analysis_extract_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test function extraction"""
import pytest
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload
from tvm.relay.testing import run_opt_pass


def get_conv_net():
"""This gets the net for:
conv2d
/ |
/ |
conv2d |
\ |
\ |
elemwise add
|
"""
dshape = (1, 1, 5, 1)
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)

z = relay.add(y, x1)

return tvm.IRModule.from_expr(z)


def get_conv2d():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
return tvm.IRModule.from_expr(y)


def test_extract_identity():
mod = get_conv2d()
op_freqs = relay.analysis.list_op_freqs(mod)
assert len(op_freqs) == 1
assert op_freqs["nn.conv2d"] == 1


def test_extract_conv_net():
mod = get_conv_net()
op_freqs = relay.analysis.list_op_freqs(mod)
assert len(op_freqs) == 2
assert op_freqs["add"] == 1
assert op_freqs["nn.conv2d"] == 2


def test_extract_fused():
mod = get_conv_net()
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps(3)(mod)

op_freqs = relay.analysis.list_op_freqs(mod)
assert len(op_freqs) == 2
assert op_freqs["add"] == 1
assert op_freqs["nn.conv2d"] == 2


def test_extract_resnet():
mod, _params = get_workload()
expected_op_freqs = {
"nn.batch_norm": 19,
"nn.conv2d": 21,
"nn.relu": 18,
"nn.max_pool2d": 1,
"add": 8,
"nn.global_avg_pool2d": 1,
"nn.batch_flatten": 1,
"nn.dense": 1,
"nn.bias_add": 1,
"nn.softmax": 1,
}
op_freqs = relay.analysis.list_op_freqs(mod)
assert len(op_freqs) == len(expected_op_freqs)
assert all([op_freqs[op] == expected_op_freqs[op] for op in expected_op_freqs])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 45b0299

Please sign in to comment.