diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 3b38c07a0a8a..12f659f0037c 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -431,3 +431,41 @@ def get_calibration_data(mod, data): calib_data[gvar] = value return calib_data + + +def extract_intermdeiate_expr(mod, expr_id): + """Extract Relay Expr by its expression ID + + This function is used for extracting Relay Expr + by its expression ID of the main function + that we can see in `print(mod["main"])`. + + Parameters + ---------- + mod : tvm.IRModule + + expr_id : the Expr ID that we want to extract + + Returns + ------- + ret : Extracted IRModule + + Examples + -------- + .. code-block:: python + + # Suppose our module is printed like this: + # def @main(%x: Tensor[(1, 1, 5, 1), float32], %w1, %w2) { + # %0 = nn.conv2d(%x, %w1, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]); + # %1 = nn.conv2d(%0, %w2, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]); + # %2 = add(%0, %1); + # %3 = split(%2, indices_or_sections=1); + # %4 = %3.0; + # add(%4, 1f) + # } + # if we want to extract `%1 = nn.conv2d` + from tvm import relay + + relay.analysis.extract_intermdeiate_expr(mod, 1) + """ + return _ffi_api.ExtractIntermediateExpr(mod, expr_id) diff --git a/src/relay/analysis/extract_intermediate_expr.cc b/src/relay/analysis/extract_intermediate_expr.cc new file mode 100644 index 000000000000..d7466e2729db --- /dev/null +++ b/src/relay/analysis/extract_intermediate_expr.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file extract_intermediate_expr.cc + * \brief Used for extracting Relay Expr + by the expression ID of the main function + that we can see in `print(mod["main"])`. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +class ExtractIntermediateExprWrapper : private MixedModeVisitor { + public: + explicit ExtractIntermediateExprWrapper(const IRModule& mod, const int expr_id) + : mod_(mod), target_expr_id_(expr_id), counter_(0) {} + + IRModule Extract() { + VisitExpr(this->mod_->Lookup("main")); + + // ensure the target expr_id we want to extract is valid. + ICHECK(target_expr_id_ >= 0 && target_expr_id_ < counter_); + + return IRModule::FromExpr(target_op_, {}); + } + + private: + using MixedModeVisitor::VisitExpr_; + + const IRModule mod_; + /*! \brief the expr id that we want to extract. */ + const int target_expr_id_; + int counter_; + Expr target_op_; + + void VisitExpr_(const CallNode* n) final { + CheckCounterAndIncrease(GetRef(n)); + MixedModeVisitor::VisitExpr_(n); + } + + void VisitExpr_(const TupleNode* n) final { + CheckCounterAndIncrease(GetRef(n)); + MixedModeVisitor::VisitExpr_(n); + } + + void VisitExpr_(const TupleGetItemNode* n) final { + CheckCounterAndIncrease(GetRef(n)); + MixedModeVisitor::VisitExpr_(n); + } + + void CheckCounterAndIncrease(const Expr& expr) { + if (target_expr_id_ == counter_) { + target_op_ = expr; + } + ++counter_; + } +}; + +IRModule ExtractIntermediateExprPacked(const IRModule& mod, const int expr_id) { + return ExtractIntermediateExprWrapper(mod, expr_id).Extract(); +} + +TVM_REGISTER_GLOBAL("relay.analysis.ExtractIntermediateExpr") + .set_body_typed(ExtractIntermediateExprPacked); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_analysis_extract_intermediate_expr.py b/tests/python/relay/test_analysis_extract_intermediate_expr.py new file mode 100644 index 000000000000..abcaf880b4aa --- /dev/null +++ b/tests/python/relay/test_analysis_extract_intermediate_expr.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test function extraction""" +import pytest +import tvm +from tvm import relay + + +def get_conv_net(): + """This gets the net for: + conv2d + / | + / | + conv2d | + \ | + \ | + elemwise add + | + | + | + split + | + | + | + elemwise add + """ + dshape = (1, 1, 5, 1) + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) + + z = relay.add(y, x1) + + tuple_out = relay.op.split(z, indices_or_sections=1, axis=0) + + tuple_0_add = relay.add(tuple_out[0], relay.const(1, dtype="float32")) + + return tvm.IRModule.from_expr(tuple_0_add) + + +def get_conv2d(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + return tvm.IRModule.from_expr(y) + + +def test_extract(): + dshape = (1, 1, 5, 1) + + def before(): + return get_conv_net() + + def expected_0(): + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + return tvm.IRModule.from_expr(y) + + def expected_1(): + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) + return tvm.IRModule.from_expr(x1) + + def expected_2(): + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) + z = relay.add(y, x1) + return tvm.IRModule.from_expr(z) + + def expected_3(): + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) + z = relay.add(y, x1) + tuple_out = relay.op.split(z, indices_or_sections=1, axis=0) + return tvm.IRModule.from_expr(tuple_out.astuple()) + + def expected_4(): + # check tuple node + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) + z = relay.add(y, x1) + tuple_out = relay.op.split(z, indices_or_sections=1, axis=0) + return tvm.IRModule.from_expr(tuple_out[0]) + + assert tvm.ir.structural_equal( + relay.analysis.extract_intermdeiate_expr(before(), 0), expected_0() + ) + assert tvm.ir.structural_equal( + relay.analysis.extract_intermdeiate_expr(before(), 1), expected_1() + ) + assert tvm.ir.structural_equal( + relay.analysis.extract_intermdeiate_expr(before(), 2), expected_2() + ) + assert tvm.ir.structural_equal( + (relay.analysis.extract_intermdeiate_expr(before(), 3)), expected_3() + ) + assert tvm.ir.structural_equal( + relay.analysis.extract_intermdeiate_expr(before(), 4), expected_4() + ) + assert tvm.ir.structural_equal(relay.analysis.extract_intermdeiate_expr(before(), 5), before()) + + +if __name__ == "__main__": + pytest.main([__file__])