Skip to content

Commit

Permalink
[Relay] Extract Intermediate Expr by relay expr ID for analysis
Browse files Browse the repository at this point in the history
modify doc comments
  • Loading branch information
Bin Li committed Sep 1, 2022
1 parent aa6c712 commit 710ecfa
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,41 @@ def get_calibration_data(mod, data):
calib_data[gvar] = value

return calib_data


def extract_intermdeiate_expr(mod, expr_id):
"""Extract Relay Expr by its expression ID
This function is used for extracting Relay Expr
by its expression ID of the main function
that we can see in `print(mod["main"])`.
Parameters
----------
mod : tvm.IRModule
expr_id : the Expr ID that we want to extract
Returns
-------
ret : Extracted IRModule
Examples
--------
.. code-block:: python
# Suppose our module is printed like this:
# def @main(%x: Tensor[(1, 1, 5, 1), float32], %w1, %w2) {
# %0 = nn.conv2d(%x, %w1, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
# %1 = nn.conv2d(%0, %w2, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
# %2 = add(%0, %1);
# %3 = split(%2, indices_or_sections=1);
# %4 = %3.0;
# add(%4, 1f)
# }
# if we want to extract `%1 = nn.conv2d`
from tvm import relay
relay.analysis.extract_intermdeiate_expr(mod, 1)
"""
return _ffi_api.ExtractIntermediateExpr(mod, expr_id)
88 changes: 88 additions & 0 deletions src/relay/analysis/extract_intermediate_expr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file extract_intermediate_expr.cc
* \brief Used for extracting Relay Expr
by the expression ID of the main function
that we can see in `print(mod["main"])`.
*/
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

class ExtractIntermediateExprWrapper : private MixedModeVisitor {
public:
explicit ExtractIntermediateExprWrapper(const IRModule& mod, const int expr_id)
: mod_(mod), target_expr_id_(expr_id), counter_(0) {}

IRModule Extract() {
VisitExpr(this->mod_->Lookup("main"));

// ensure the target expr_id we want to extract is valid.
ICHECK(target_expr_id_ >= 0 && target_expr_id_ < counter_);

return IRModule::FromExpr(target_op_, {});
}

private:
using MixedModeVisitor::VisitExpr_;

const IRModule mod_;
/*! \brief the expr id that we want to extract. */
const int target_expr_id_;
int counter_;
Expr target_op_;

void VisitExpr_(const CallNode* n) final {
CheckCounterAndIncrease(GetRef<Expr>(n));
MixedModeVisitor::VisitExpr_(n);
}

void VisitExpr_(const TupleNode* n) final {
CheckCounterAndIncrease(GetRef<Expr>(n));
MixedModeVisitor::VisitExpr_(n);
}

void VisitExpr_(const TupleGetItemNode* n) final {
CheckCounterAndIncrease(GetRef<Expr>(n));
MixedModeVisitor::VisitExpr_(n);
}

void CheckCounterAndIncrease(const Expr& expr) {
if (target_expr_id_ == counter_) {
target_op_ = expr;
}
++counter_;
}
};

IRModule ExtractIntermediateExprPacked(const IRModule& mod, const int expr_id) {
return ExtractIntermediateExprWrapper(mod, expr_id).Extract();
}

TVM_REGISTER_GLOBAL("relay.analysis.ExtractIntermediateExpr")
.set_body_typed(ExtractIntermediateExprPacked);

} // namespace relay
} // namespace tvm
130 changes: 130 additions & 0 deletions tests/python/relay/test_analysis_extract_intermediate_expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test function extraction"""
import pytest
import tvm
from tvm import relay


def get_conv_net():
"""This gets the net for:
conv2d
/ |
/ |
conv2d |
\ |
\ |
elemwise add
|
|
|
split
|
|
|
elemwise add
"""
dshape = (1, 1, 5, 1)
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)

z = relay.add(y, x1)

tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)

tuple_0_add = relay.add(tuple_out[0], relay.const(1, dtype="float32"))

return tvm.IRModule.from_expr(tuple_0_add)


def get_conv2d():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
return tvm.IRModule.from_expr(y)


def test_extract():
dshape = (1, 1, 5, 1)

def before():
return get_conv_net()

def expected_0():
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
return tvm.IRModule.from_expr(y)

def expected_1():
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
return tvm.IRModule.from_expr(x1)

def expected_2():
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
z = relay.add(y, x1)
return tvm.IRModule.from_expr(z)

def expected_3():
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
z = relay.add(y, x1)
tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
return tvm.IRModule.from_expr(tuple_out.astuple())

def expected_4():
# check tuple node
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
z = relay.add(y, x1)
tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
return tvm.IRModule.from_expr(tuple_out[0])

assert tvm.ir.structural_equal(
relay.analysis.extract_intermdeiate_expr(before(), 0), expected_0()
)
assert tvm.ir.structural_equal(
relay.analysis.extract_intermdeiate_expr(before(), 1), expected_1()
)
assert tvm.ir.structural_equal(
relay.analysis.extract_intermdeiate_expr(before(), 2), expected_2()
)
assert tvm.ir.structural_equal(
(relay.analysis.extract_intermdeiate_expr(before(), 3)), expected_3()
)
assert tvm.ir.structural_equal(
relay.analysis.extract_intermdeiate_expr(before(), 4), expected_4()
)
assert tvm.ir.structural_equal(relay.analysis.extract_intermdeiate_expr(before(), 5), before())


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 710ecfa

Please sign in to comment.