diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 1e9b86d9e0bc..f916dbeb713f 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -83,9 +83,9 @@ using TOpIsStateful = bool;
 using TNonComputational = bool;

 /*!
- * \brief Mark the operator whether output shape is data dependant.
+ * \brief Mark the operator whether output shape is data dependent.
  */
-using TShapeDataDependant = bool;
+using TShapeDataDependent = Array<Integer>;

 /*!
  * \brief Computation description interface.
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index b61d4f9655f6..a36b56214bc4 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -32,11 +32,8 @@


 @script
-def _reshape_shape_func_input_data(data, newshape, ndim):
+def _reshape_shape_func_input_data(data_shape, newshape, ndim):
     out = output_tensor((ndim,), "int64")
-    data_shape = allocate((len(data.shape),), "int64")
-    for x in const_range(len(data.shape)):
-        data_shape[x] = int64(data.shape[x])
     src_idx = 0
     dst_idx = 0
     infer_idx = -1
@@ -87,7 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
     return out


-@_reg.register_shape_func("dyn.reshape", True)
+@_reg.register_shape_func("dyn.reshape", [False, True])
 def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
     return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]

@@ -150,36 +147,36 @@ def one_hot_shape_func(attrs, inputs, _):


 @script
-def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
-    ndim = len(data.shape)
+def _strided_slice_shape_func_input_data(data_shape, begin, end, strides, slice_mode):
+    ndim = len(data_shape)
     out = output_tensor((ndim,), "int64")
     for i in const_range(ndim):
         cbegin = int64(0)
-        cend = int64(data.shape[i])
+        cend = int64(data_shape[i])
         cstride = int64(1)
         if strides.shape[0] > i:
             cstride = int64(strides[i])
         if begin.shape[0] > i:
             cbegin = int64(begin[i])
             if cbegin < 0:
-                cbegin += int64(data.shape[i])
+                cbegin += int64(data_shape[i])
         if end.shape[0] <= i:
-            cend = int64(data.shape[i])
+            cend = int64(data_shape[i])
         elif slice_mode != 0:
             cstride = int64(1)
             if end[i] < 0:
-                cend = int64(data.shape[i])
+                cend = int64(data_shape[i])
             else:
                 cend = cbegin + int64(end[i])
         else:
-            if end[i] > data.shape[i]:
-                cend = int64(data.shape[i])
-            elif end[i] < -data.shape[i]:
+            if end[i] > data_shape[i]:
+                cend = int64(data_shape[i])
+            elif end[i] < -data_shape[i]:
                 cend = int64(-1)
             else:
                 cend = int64(end[i])
                 if cend < 0:
-                    cend += int64(data.shape[i])
+                    cend += int64(data_shape[i])
         assert cstride != 0, "Strides can't be zero."
         if cstride < 0:
             slice_range = cbegin - cend
@@ -192,7 +189,7 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
     return out


-@_reg.register_shape_func("dyn.strided_slice", True)
+@_reg.register_shape_func("dyn.strided_slice", [False, True, True, True])
 def strided_slice_shape_func(attrs, inputs, _):
     """
     Shape func for strided_slice
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index d4d20b3ebc4a..5882027fb1d8 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -356,7 +356,7 @@ def register_gradient(op_name, fgradient=None, level=10):
     return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)


-def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
+def register_shape_func(op_name, data_dependent, shape_func=None, level=10):
     """Register operator shape function for an op.

     Parameters
@@ -364,8 +364,10 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
     op_name : str
         The name of the op.

-    data_dependant : bool
-        Whether the shape function depends on input data.
+    data_dependent : bool or list of bool
+        Whether the shape function depends on input data. If this is a list of bool,
+        the length of the list must be the same as the number of arguments of this op.
+        The list specifies per-input data dependence of the op.

     shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
                  -> shape_tensors: List<Tensor>
@@ -374,7 +376,9 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
     level : int
         The priority level
     """
-    get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
+    if not isinstance(data_dependent, list):
+        data_dependent = [data_dependent]
+    get(op_name).set_attr("TShapeDataDependent", data_dependent, level)
     return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)

diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index bcfbc83da514..abb9e6b034c2 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -473,24 +473,27 @@ bool IsDynamic(const Type& ty) {

 TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);

-bool IsDataDependant(const CallNode* call) {
-  static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
+bool IsDataDependent(const CallNode* call) {
+  static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
   Op op = Downcast<Op>(call->op);
-  if (!tshape_data_dependant.count(op)) {
+  if (!tshape_data_dependent.count(op)) {
     return false;
   }
   if (op->name == "strided_slice") {
     if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
       if (attrs->begin && attrs->end && attrs->strides) {
-        // not data dependant if begin, end and strides exist
+        // not data dependent if begin, end and strides exist
         return false;
       }
     }
   }
-  return tshape_data_dependant[op];
+  for (auto req : tshape_data_dependent[op]) {
+    if (req->value != 0) return true;
+  }
+  return false;
 }

 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index c969c3ba7f06..a66ae0a7e2c0 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -435,9 +435,9 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
       LOG(FATAL) << "Free variable " << var->name_hint();
       return {};
     } else {
-      ICHECK(data_dependants_.size());
-      bool data_dependant = data_dependants_.back();
-      if (data_dependant) {
+      ICHECK(data_dependents_per_input_.size());
+      auto data_dependent = data_dependents_per_input_.back();
+      if (data_dependent) {
         param_states_[var] |= kNeedInputData;
         return param_data_[var];
       } else {
@@ -449,12 +449,12 @@ Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
     using tir::make_const;
-    ICHECK(data_dependants_.size());
-    bool data_dependant = data_dependants_.back();
+    ICHECK(data_dependents_per_input_.size());
+    bool data_dependent = data_dependents_per_input_.back();
     if (!op->is_scalar()) {
       // This is a constant weight, extract the shape of the weight tensor.
       // This can not be data dependent.
-      CHECK(!data_dependant);
+      CHECK(!data_dependent);
       auto ttype = op->checked_type().as<TensorTypeNode>();
       int ndim = static_cast<int>(ttype->shape.size());
       Array<PrimExpr> out_shape{ndim};
@@ -472,7 +472,7 @@
       scalars_.push_back(value);
       return {value};
     }
-    if (data_dependant) {
+    if (data_dependent) {
       void* data = op->data->data;
       DataType dtype = DataType(op->data->dtype);
       auto value = tvm::te::compute(
@@ -507,27 +507,38 @@ Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
     static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
-    static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
+    static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
-    ICHECK(data_dependants_.empty() || !data_dependants_.back())
+    ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
         << "Error in op fusion: output of the shape func is fed to a "
-        << "data-dependant shape func";
+        << "data-dependent shape func";
     ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
-    ICHECK_GT(tshape_data_dependant.count(op), 0)
-        << "Internal error, cannot find TShapeDataDependant for " << op->name;
+    ICHECK_GT(tshape_data_dependent.count(op), 0)
+        << "Internal error, cannot find TShapeDataDependent for " << op->name;
+
+    Array<Integer> dep_spec = tshape_data_dependent[op];
+    if (dep_spec.size() == 1) {
+      // This is for cases when data dependence is specified per op
+      // Replicate 0 or 1 flag to all arguments
+      for (size_t i = 1; i < call_node->args.size(); ++i) {
+        dep_spec.push_back(dep_spec[0]);
+      }
+    }

-    data_dependants_.push_back(IsDataDependant(call_node));
     // Visit all inputs
     Array<te::Tensor> inputs;
     int count_tuple = 0;
-    for (Expr arg : call_node->args) {
+    for (size_t i = 0; i < call_node->args.size(); ++i) {
+      Expr arg = call_node->args[i];
       if (arg->checked_type().as<TupleTypeNode>()) {
         ++count_tuple;
       }
+      data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
       for (te::Tensor tensor : VisitExpr(arg)) {
         inputs.push_back(tensor);
       }
+      data_dependents_per_input_.pop_back();
     }
     if (count_tuple) {
       ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
@@ -549,7 +560,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
     }
     // Call shape function
     auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
-    data_dependants_.pop_back();
     readable_name_stream_ << "_" << op->name;
     return outputs;
   }
@@ -593,8 +603,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
   std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
   /*! \brief Map from parameter to list of shape placeholder */
   std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
-  /*! \brief Stack of data dependencies for shape function */
-  std::vector<bool> data_dependants_;
+  /*! \brief Stack of data dependencies for shape function, specified per each op input */
+  std::vector<bool> data_dependents_per_input_;
   /*! \brief Scalars used in the shape function */
   Array<te::Tensor> scalars_;
 };
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 29f3bfa0a17e..1b28980a0a2f 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -241,7 +241,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     OpPatternKind op_pattern = kOpaque;
     if (const OpNode* opnode = call->op.as<OpNode>()) {
       auto op = GetRef<Op>(opnode);
-      if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
+      if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
         // output of a shape func can't be fed to a data-dependent shape func
         op_pattern = kOpaque;
       } else {
diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h
index a2f22cbbf106..bb2f268a23d7 100644
--- a/src/relay/transforms/pass_utils.h
+++ b/src/relay/transforms/pass_utils.h
@@ -90,11 +90,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
 bool IsDynamic(const Type& ty);

 /*!
- * \brief Check if call is data dependant.
+ * \brief Check if call is data dependent.
  * \param call The call to be checked.
- * \return Whether the call is data dependant.
+ * \return Whether the call is data dependent.
  */
-bool IsDataDependant(const CallNode* call);
+bool IsDataDependent(const CallNode* call);

 /*!
  * \brief Make arbitrary transformation preserve the out most function.
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 3212f9079619..a15cdcd3926b 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -105,7 +105,9 @@ TEST(Relay, BuildModule) {
   }
   auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
   (*reg)("add", "FTVMStrategy", fgeneric, 10);
-  (*reg)("add", "TShapeDataDependant", false, 10);
+  Array<Integer> dep;
+  dep.push_back(0);
+  (*reg)("add", "TShapeDataDependent", dep, 10);
   // build
   auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
   tvm::runtime::Module build_mod = (*pfb)();
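
For context, below is a minimal sketch of how an op author would use the per-input registration form introduced by this patch. The op name "my.dyn_op", its hybrid-script shape function, and the assumption that the output rank equals the rank of the first input are hypothetical; only the decorator usage and the meaning of the [False, True] flags mirror the dyn.reshape and dyn.strided_slice registrations above.

# Hypothetical example (not part of this patch): registering a shape function
# with the new per-input data-dependence flags. "my.dyn_op" is a made-up op
# with two inputs; only the second input's values are needed to compute the
# output shape, so its flag is True while the first stays False.
from tvm.te.hybrid import script
from tvm.relay.op import op as _reg


@script
def _my_dyn_op_shape_func(data_shape, sizes):
    # Input 0 is flagged False, so MakeShapeFunc feeds in its shape tensor
    # (data_shape); input 1 is flagged True, so its raw values (sizes) are
    # available here and define the output shape.
    ndim = len(data_shape)
    out = output_tensor((ndim,), "int64")
    for i in const_range(ndim):
        out[i] = int64(sizes[i])
    return out


@_reg.register_shape_func("my.dyn_op", [False, True])
def my_dyn_op_shape_func(attrs, inputs, _):
    return [_my_dyn_op_shape_func(inputs[0], inputs[1])]

Passing a single bool still works: register_shape_func wraps it in a one-element list, and MakeShapeFunc replicates that flag across all arguments, as shown by the dep_spec handling in compile_engine.cc above.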