[VM] Per-input, data dependence specification for shape func #7210

Merged 13 commits on Jan 15, 2021
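This PR makes the data-dependence declaration of a shape function per input rather than per op: the boolean TShapeDataDependant attribute is replaced by TShapeDataDependent, an Array<Integer> holding one flag per argument, and register_shape_func now accepts either a single bool or a list of bools. An input flagged as data dependent is fed to the shape function as its data; an unflagged input is fed only as its shape. As a concrete case taken from this diff, dyn.reshape needs the values of newshape but only the shape of data:

    @_reg.register_shape_func("dyn.reshape", [False, True])
    def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
        return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]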
4 changes: 2 additions & 2 deletions include/tvm/relay/op_attr_types.h
@@ -83,9 +83,9 @@ using TOpIsStateful = bool;
using TNonComputational = bool;

/*!
* \brief Mark the operator whether output shape is data dependant.
* \brief Mark the operator whether output shape is data dependent.
*/
using TShapeDataDependant = bool;
using TShapeDataDependent = Array<Integer>;

/*!
* \brief Computation description interface.
29 changes: 13 additions & 16 deletions python/tvm/relay/op/dyn/_transform.py
@@ -32,11 +32,8 @@


@script
def _reshape_shape_func_input_data(data, newshape, ndim):
def _reshape_shape_func_input_data(data_shape, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
dst_idx = 0
infer_idx = -1
@@ -87,7 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
return out


@_reg.register_shape_func("dyn.reshape", True)
@_reg.register_shape_func("dyn.reshape", [False, True])
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]

@@ -150,36 +147,36 @@ def one_hot_shape_func(attrs, inputs, _):


@script
def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
ndim = len(data.shape)
def _strided_slice_shape_func_input_data(data_shape, begin, end, strides, slice_mode):
ndim = len(data_shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
cbegin = int64(0)
cend = int64(data.shape[i])
cend = int64(data_shape[i])
cstride = int64(1)
if strides.shape[0] > i:
cstride = int64(strides[i])
if begin.shape[0] > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data.shape[i])
cbegin += int64(data_shape[i])
if end.shape[0] <= i:
cend = int64(data.shape[i])
cend = int64(data_shape[i])
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = int64(data.shape[i])
cend = int64(data_shape[i])
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data.shape[i]:
cend = int64(data.shape[i])
elif end[i] < -data.shape[i]:
if end[i] > data_shape[i]:
cend = int64(data_shape[i])
elif end[i] < -data_shape[i]:
cend = int64(-1)
else:
cend = int64(end[i])
if cend < 0:
cend += int64(data.shape[i])
cend += int64(data_shape[i])
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
@@ -192,7 +189,7 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
return out


@_reg.register_shape_func("dyn.strided_slice", True)
@_reg.register_shape_func("dyn.strided_slice", [False, True, True, True])
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
12 changes: 8 additions & 4 deletions python/tvm/relay/op/op.py
@@ -356,16 +356,18 @@ def register_gradient(op_name, fgradient=None, level=10):
return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)


def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
def register_shape_func(op_name, data_dependent, shape_func=None, level=10):
"""Register operator shape function for an op.

Parameters
----------
op_name : str
The name of the op.

data_dependant : bool
Whether the shape function depends on input data.
data_dependent : bool or list of bool
Whether the shape function depends on input data. If this is a list of bool,
the length of the list must be the same as the number of arguments of this op.
The list specifies per-input data dependence of the op.

shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
-> shape_tensors: List<Tensor>
@@ -374,7 +376,9 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
level : int
The priority level
"""
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
if not isinstance(data_dependent, list):
data_dependent = [data_dependent]
get(op_name).set_attr("TShapeDataDependent", data_dependent, level)
return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)


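For reference, the bool form remains supported: register_shape_func wraps a bare bool into a one-element list (see above), and the compile engine then replicates a single flag across all arguments of the call. A minimal standalone sketch of that normalization, in plain Python rather than TVM code (normalize_dep_spec is a hypothetical helper for illustration):

    def normalize_dep_spec(data_dependent, num_args):
        # Legacy per-op form: a bare bool becomes a one-element list.
        if not isinstance(data_dependent, list):
            data_dependent = [data_dependent]
        # Single-entry spec: replicate the flag for every call argument,
        # mirroring the broadcast done in MakeShapeFunc::VisitExpr_(CallNode*).
        if len(data_dependent) == 1:
            data_dependent = data_dependent * num_args
        return data_dependent

    assert normalize_dep_spec(True, 3) == [True, True, True]
    assert normalize_dep_spec([False, True], 2) == [False, True]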
13 changes: 8 additions & 5 deletions src/relay/analysis/util.cc
@@ -473,24 +473,27 @@ bool IsDynamic(const Type& ty) {

TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);

bool IsDataDependant(const CallNode* call) {
static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
bool IsDataDependent(const CallNode* call) {
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
Op op = Downcast<Op>(call->op);

if (!tshape_data_dependant.count(op)) {
if (!tshape_data_dependent.count(op)) {
return false;
}

if (op->name == "strided_slice") {
if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
if (attrs->begin && attrs->end && attrs->strides) {
// not data dependant if begin, end and strides exist
// not data dependent if begin, end and strides exist
return false;
}
}
}

return tshape_data_dependant[op];
for (auto req : tshape_data_dependent[op]) {
if (req->value != 0) return true;
}
return false;
}
} // namespace relay
} // namespace tvm
44 changes: 27 additions & 17 deletions src/relay/backend/compile_engine.cc
@@ -435,9 +435,9 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
LOG(FATAL) << "Free variable " << var->name_hint();
return {};
} else {
ICHECK(data_dependants_.size());
bool data_dependant = data_dependants_.back();
if (data_dependant) {
ICHECK(data_dependents_per_input_.size());
auto data_dependent = data_dependents_per_input_.back();
if (data_dependent) {
param_states_[var] |= kNeedInputData;
return param_data_[var];
} else {
@@ -449,12 +449,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>

Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(data_dependants_.size());
bool data_dependant = data_dependants_.back();
ICHECK(data_dependents_per_input_.size());
bool data_dependent = data_dependents_per_input_.back();
if (!op->is_scalar()) {
// This is a constant weight, extract the shape of the weight tensor.
// This can not be data dependent.
CHECK(!data_dependant);
CHECK(!data_dependent);
auto ttype = op->checked_type().as<TensorTypeNode>();
int ndim = static_cast<int>(ttype->shape.size());
Array<PrimExpr> out_shape{ndim};
@@ -472,7 +472,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
scalars_.push_back(value);
return {value};
}
if (data_dependant) {
if (data_dependent) {
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
auto value = tvm::te::compute(
@@ -507,27 +507,38 @@

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
ICHECK(data_dependants_.empty() || !data_dependants_.back())
ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
<< "Error in op fusion: output of the shape func is fed to a "
<< "data-dependant shape func";
<< "data-dependent shape func";
ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
ICHECK_GT(tshape_data_dependant.count(op), 0)
<< "Internal error, cannot find TShapeDataDependant for " << op->name;
ICHECK_GT(tshape_data_dependent.count(op), 0)
<< "Internal error, cannot find TShapeDataDependent for " << op->name;

Array<Integer> dep_spec = tshape_data_dependent[op];
if (dep_spec.size() == 1) {
// This is for cases when data dependence is specified per op
// Replicate 0 or 1 flag to all arguments
for (size_t i = 1; i < call_node->args.size(); ++i) {
dep_spec.push_back(dep_spec[0]);
}
}

data_dependants_.push_back(IsDataDependant(call_node));
// Visit all inputs
Array<te::Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
for (size_t i = 0; i < call_node->args.size(); ++i) {
Expr arg = call_node->args[i];
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
for (te::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
data_dependents_per_input_.pop_back();
}
if (count_tuple) {
ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
@@ -549,7 +560,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
}
// Call shape function
auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
data_dependants_.pop_back();
readable_name_stream_ << "_" << op->name;
return outputs;
}
@@ -593,8 +603,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Stack of data dependencies for shape function, specified per each op input */
std::vector<bool> data_dependents_per_input_;
/*! \brief Scalars used in the shape function */
Array<te::Tensor> scalars_;
};
2 changes: 1 addition & 1 deletion src/relay/transforms/fuse_ops.cc
@@ -241,7 +241,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call->op.as<OpNode>()) {
auto op = GetRef<Op>(opnode);
if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
// output of a shape func can't be fed to a data-dependent shape func
op_pattern = kOpaque;
} else {
6 changes: 3 additions & 3 deletions src/relay/transforms/pass_utils.h
@@ -90,11 +90,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
bool IsDynamic(const Type& ty);

/*!
* \brief Check if call is data dependant.
* \brief Check if call is data dependent.
* \param call The call to be checked.
* \return Whether the call is data dependant.
* \return Whether the call is data dependent.
*/
bool IsDataDependant(const CallNode* call);
bool IsDataDependent(const CallNode* call);

/*!
* \brief Make arbitrary transformation preserve the out most function.
4 changes: 3 additions & 1 deletion tests/cpp/relay_build_module_test.cc
@@ -105,7 +105,9 @@ TEST(Relay, BuildModule) {
}
auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
(*reg)("add", "FTVMStrategy", fgeneric, 10);
(*reg)("add", "TShapeDataDependant", false, 10);
Array<Integer> dep;
dep.push_back(0);
(*reg)("add", "TShapeDataDependent", dep, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();