Skip to content

Commit

Permalink
remove data_dependent_
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 15, 2021
1 parent 94ced0d commit 6c1b318
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
15 changes: 6 additions & 9 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
LOG(FATAL) << "Free variable " << var->name_hint();
return {};
} else {
ICHECK(data_dependents_.size());
ICHECK(data_dependents_per_input_.size());
auto data_dependent = data_dependents_per_input_.back();
if (data_dependent) {
param_states_[var] |= kNeedInputData;
Expand All @@ -449,8 +449,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>

Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(data_dependents_.size());
bool data_dependent = data_dependents_.back();
ICHECK(data_dependents_per_input_.size());
bool data_dependent = data_dependents_per_input_.back();
if (!op->is_scalar()) {
// This is a constant weight, extract the shape of the weight tensor.
// This can not be data dependent.
Expand Down Expand Up @@ -510,17 +510,17 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
ICHECK(data_dependents_.empty() || !data_dependents_.back())
ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
<< "Error in op fusion: output of the shape func is fed to a "
<< "data-dependent shape func";
ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
ICHECK_GT(tshape_data_dependent.count(op), 0)
<< "Internal error, cannot find TShapeDataDependent for " << op->name;

data_dependents_.push_back(IsDataDependent(call_node));

Array<Integer> dep_spec = tshape_data_dependent[op];
if (dep_spec.size() == 1) {
// This is for cases when data dependence is specified per op
// Replicate 0 or 1 flag to all arguments
for (size_t i = 1; i < call_node->args.size(); ++i) {
dep_spec.push_back(dep_spec[0]);
}
Expand Down Expand Up @@ -560,7 +560,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
}
// Call shape function
auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
data_dependents_.pop_back();
readable_name_stream_ << "_" << op->name;
return outputs;
}
Expand Down Expand Up @@ -604,8 +603,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
/*! \brief Stack of data dependencies for shape function, specified per op */
std::vector<bool> data_dependents_;
/*! \brief Stack of data dependencies for shape function, specified per each op input */
std::vector<bool> data_dependents_per_input_;
/*! \brief Scalars used in the shape function */
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call->op.as<OpNode>()) {
auto op = GetRef<Op>(opnode);
if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
// output of a shape func can't be fed to a data-dependent shape func
op_pattern = kOpaque;
} else {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/relay_build_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ TEST(Relay, BuildModule) {
(*reg)("add", "FTVMStrategy", fgeneric, 10);
Array<Integer> dep;
dep.push_back(0);
(*reg)("add", "TShapeDataDependant", dep, 10);
(*reg)("add", "TShapeDataDependent", dep, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
Expand Down

0 comments on commit 6c1b318

Please sign in to comment.