diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index c7b6c60849a1..524f69bcdd13 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -384,6 +384,23 @@ def extract_fused_functions(mod): return ret +def list_op_freqs(mod): + """Pass to extract unique operator names and how frequently they appear + in an IRModule. Fused functions are traversed to count the operators + that compose them. + + Parameters + ---------- + mod : tvm.IRModule + + Returns + ------- + ret : Dict[str, int] + Dict of unique operator names to frequency + """ + return _ffi_api.ExtractOperators(mod) + + def search_fc_transpose(expr): """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0])) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc new file mode 100644 index 000000000000..8fd0f87239ff --- /dev/null +++ b/src/relay/analysis/extract_operators.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file extract_operators.cc + * \brief Extract unique operators from an IRModule + */ +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +class OperatorExtractorWrapper : private MixedModeVisitor { + public: + explicit OperatorExtractorWrapper(const IRModule& mod) : mod_(mod) {} + + Map Extract() { + VisitExpr(this->mod_->Lookup("main")); + + return operator_freqs_; + } + + private: + const IRModule mod_; + /*! \brief Map of operator to frequency. */ + Map operator_freqs_; + + void VisitExpr_(const CallNode* n) final { + VisitExpr(n->op); + + auto op = n->op.as(); + if (op) { + auto it = operator_freqs_.find(op->name); + ICHECK(it != operator_freqs_.end()) + << "Call's OpNode must be visited and registered before access"; + operator_freqs_.Set(op->name, 1 + operator_freqs_.at(op->name)); + } + + MixedModeVisitor::VisitExpr_(n); + } + + void VisitExpr_(const OpNode* n) final { + // NOTE: OpNode is visited only once for every operator kind + // regardless of how many times that op appears in the graph. + operator_freqs_.Set(n->name, 0U); + } +}; + +Map ExtractOperatorsPacked(const IRModule& mod) { + return OperatorExtractorWrapper(mod).Extract(); +} + +TVM_REGISTER_GLOBAL("relay.analysis.ExtractOperators").set_body_typed(ExtractOperatorsPacked); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_analysis_extract_operators.py b/tests/python/relay/test_analysis_extract_operators.py new file mode 100644 index 000000000000..5878b2a6e497 --- /dev/null +++ b/tests/python/relay/test_analysis_extract_operators.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test function extraction""" +import pytest +import tvm +from tvm import relay +from tvm.relay.testing.resnet import get_workload +from tvm.relay.testing import run_opt_pass + + +def get_conv_net(): + """This gets the net for: + conv2d + / | + / | + conv2d | + \ | + \ | + elemwise add + | + """ + dshape = (1, 1, 5, 1) + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) + + z = relay.add(y, x1) + + return tvm.IRModule.from_expr(z) + + +def get_conv2d(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + return tvm.IRModule.from_expr(y) + + +def test_extract_identity(): + mod = get_conv2d() + op_freqs = relay.analysis.list_op_freqs(mod) + assert len(op_freqs) == 1 + assert op_freqs["nn.conv2d"] == 1 + + +def test_extract_conv_net(): + mod = get_conv_net() + op_freqs = relay.analysis.list_op_freqs(mod) + assert len(op_freqs) == 2 + assert op_freqs["add"] == 1 + assert op_freqs["nn.conv2d"] == 2 + + +def test_extract_fused(): + mod = get_conv_net() + mod = relay.transform.InferType()(mod) + mod = relay.transform.FuseOps(3)(mod) + + op_freqs = relay.analysis.list_op_freqs(mod) + assert len(op_freqs) == 2 + assert op_freqs["add"] == 1 + assert op_freqs["nn.conv2d"] == 2 + + +def test_extract_resnet(): + mod, _params = get_workload() + expected_op_freqs = { + "nn.batch_norm": 19, + "nn.conv2d": 21, + "nn.relu": 18, + "nn.max_pool2d": 1, + "add": 8, + "nn.global_avg_pool2d": 1, + "nn.batch_flatten": 1, + "nn.dense": 1, + "nn.bias_add": 1, + "nn.softmax": 1, + } + op_freqs = relay.analysis.list_op_freqs(mod) + assert len(op_freqs) == len(expected_op_freqs) + assert all([op_freqs[op] == expected_op_freqs[op] for op in expected_op_freqs]) + + +if __name__ == "__main__": + pytest.main([__file__])