Skip to content

Commit

Permalink
[Relay] Recursive destructor call replaced with non-recursive for Cal…
Browse files Browse the repository at this point in the history
…l nodes. (apache#7832)

* [Relay] Recursive destructor call replaced with non-recursive for Call nodes.

Recursive destructor call replaced with non-recursive (based on
ExpandDataflow) for Call nodes. This prevents OutOfStack
exception during unwinding a chain of destructors for large-sized
subtrees based on smart-pointers.

Change-Id: Ib9da3ff8af3a0a41287b8ce9ab2bee2d0813d01c

Addressed requested changes

Addressed requested changes, simplified the code
added unit test

Change-Id: I7fdd44da3b6c366a555fd9157fa3630b6e789d64

* removed inline befor Call destructor

Change-Id: I6328e423670f185393d50ccd3d6fdc1326be3767
  • Loading branch information
d-smirnov authored and mehrdadh committed Apr 22, 2021
1 parent 86fbcee commit 3ecdadf
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
#include <tvm/ir/op.h>

#include <functional>
#include <stack>
#include <string>
#include <utility>

#include "./base.h"
#include "./type.h"
Expand Down Expand Up @@ -292,6 +294,11 @@ class CallNode : public ExprNode {

class Call : public Expr {
public:
/*!
* \brief The destructor
*/
~Call();

/*!
* \brief The constructor
* \param op The operator will be invoked.
Expand Down
58 changes: 58 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,5 +258,63 @@ TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp)

TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any(); });

/*
* Non-recursive traversal with dismantling unused call nodes,
* a derivative from ExpandDataflow method
*/
inline void Dismantle(const Expr& expr) {
std::stack<std::pair<Expr, bool>> stack;
auto fpush_to_stack = [&stack](const Expr& expr) {
// do not visit nodes with more than 2 refs (one can be in stack)
if (expr.use_count() < 3) {
stack.push({expr, false});
}
};
fpush_to_stack(expr);
while (stack.size() > 0) {
const auto& node = stack.top().first;
if (stack.top().second) {
// dismantle node
// +1 ref in stack/deque;
if (node.use_count() < 3) {
if (auto* op = const_cast<CallNode*>(node.as<CallNode>())) {
op->args = Array<Expr>();
}
}
// eject
stack.pop();
} else {
stack.top().second = true;

// special handling
if (const CallNode* op = node.as<CallNode>()) {
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
fpush_to_stack(op->tuple);
}
}
}
}

/*
* Non-recursive destructor
*/

Call::~Call() {
// attempt to dismantle if referenced one or zero times
if (this->use_count() < 2) {
if (this->as<CallNode>() && this->as<CallNode>()->args.size()) {
Dismantle(*this);
}
}
}

} // namespace relay
} // namespace tvm
76 changes: 76 additions & 0 deletions tests/cpp/relay_dismantler_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type_functor.h>
#include <tvm/node/functor.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/generic/injective.h>

using namespace tvm;
using namespace tvm::relay;

TEST(Relay, OutOfStack_add) {
auto foo = [] {
auto add_op = relay::Op::Get("add");
auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto c1 = relay::Constant(c_data);
Call y1 = relay::Call(add_op, {c1, c1});
for (int i = 0; i < 1e6; i++) {
y1 = relay::Call(add_op, {c1, y1});
}
relay::Function func = relay::Function({}, y1, relay::Type(), {});
};
ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
}

TEST(Relay, OutOfStack_cast) {
auto foo = [] {
auto cast_op = relay::Op::Get("cast");
auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto c1 = relay::Constant(c_data);
Call y1 = relay::Call(cast_op, {c1});
for (int i = 0; i < 1e6; i++) {
y1 = relay::Call(cast_op, {y1});
}
relay::Function func = relay::Function({}, y1, relay::Type(), {});
};
ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}

0 comments on commit 3ecdadf

Please sign in to comment.