[VM] Per-input, data dependence specification for shape func (apache#7210)

* made TShapeDataDependant array

* add stub

* dyn strided slice working

* reshape also working

* remove log

* works on maskrcnn

* lint fix

* fix cpp test

* remove stale pop back

* add more doc

* dependant -> dependent

* remove redundant check

* remove data_dependent_
masahi authored and trevor-m committed Jan 21, 2021
1 parent 60e31e9 commit a150477
Showing 8 changed files with 65 additions and 49 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/op_attr_types.h
@@ -83,9 +83,9 @@ using TOpIsStateful = bool;
using TNonComputational = bool;

/*!
* \brief Mark the operator whether output shape is data dependant.
* \brief Mark the operator whether output shape is data dependent.
*/
using TShapeDataDependant = bool;
using TShapeDataDependent = Array<Integer>;

/*!
* \brief Computation description interface.
29 changes: 13 additions & 16 deletions python/tvm/relay/op/dyn/_transform.py
@@ -32,11 +32,8 @@


@script
def _reshape_shape_func_input_data(data, newshape, ndim):
def _reshape_shape_func_input_data(data_shape, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
dst_idx = 0
infer_idx = -1
@@ -87,7 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
return out


@_reg.register_shape_func("dyn.reshape", True)
@_reg.register_shape_func("dyn.reshape", [False, True])
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
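A quick standalone illustration (plain numpy, not TVM code) of why dyn.reshape is registered with [False, True]: the output shape is determined by the shape of the first input and the values of the second, so only `newshape` is data dependent.

import numpy as np

# Contents of `data` are irrelevant; only its shape matters.
data = np.zeros((2, 3, 4))
# Values of `newshape` matter: -1 is inferred from the remaining elements.
newshape = np.array([4, -1])
assert np.reshape(data, newshape).shape == (4, 6)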

@@ -150,36 +147,36 @@ def one_hot_shape_func(attrs, inputs, _):


@script
def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
ndim = len(data.shape)
def _strided_slice_shape_func_input_data(data_shape, begin, end, strides, slice_mode):
ndim = len(data_shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
cbegin = int64(0)
cend = int64(data.shape[i])
cend = int64(data_shape[i])
cstride = int64(1)
if strides.shape[0] > i:
cstride = int64(strides[i])
if begin.shape[0] > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data.shape[i])
cbegin += int64(data_shape[i])
if end.shape[0] <= i:
cend = int64(data.shape[i])
cend = int64(data_shape[i])
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = int64(data.shape[i])
cend = int64(data_shape[i])
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data.shape[i]:
cend = int64(data.shape[i])
elif end[i] < -data.shape[i]:
if end[i] > data_shape[i]:
cend = int64(data_shape[i])
elif end[i] < -data_shape[i]:
cend = int64(-1)
else:
cend = int64(end[i])
if cend < 0:
cend += int64(data.shape[i])
cend += int64(data_shape[i])
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
@@ -192,7 +189,7 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides, slice_mode):
return out


@_reg.register_shape_func("dyn.strided_slice", True)
@_reg.register_shape_func("dyn.strided_slice", [False, True, True, True])
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
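As a cross-check on the hybrid script above, the following plain-Python sketch reproduces the per-dimension output length for the common case of positive strides and in-range bounds (illustrative only, not the TVM implementation; negative strides and slice_mode handling are omitted).

import numpy as np

def slice_len(dim, begin, end, stride):
    # Normalize negative begin/end, clamp end to the dimension, then
    # ceil-divide the range by the stride, as the shape function above does.
    cbegin = begin + dim if begin < 0 else begin
    cend = end + dim if end < 0 else min(end, dim)
    return (cend - cbegin + stride - 1) // stride

for dim, b, e, s in [(10, 2, 8, 2), (10, -4, 10, 1), (7, 0, 7, 3)]:
    assert slice_len(dim, b, e, s) == len(np.arange(dim)[b:e:s])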
12 changes: 8 additions & 4 deletions python/tvm/relay/op/op.py
@@ -356,16 +356,18 @@ def register_gradient(op_name, fgradient=None, level=10):
return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)


def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
def register_shape_func(op_name, data_dependent, shape_func=None, level=10):
"""Register operator shape function for an op.
Parameters
----------
op_name : str
The name of the op.
data_dependant : bool
Whether the shape function depends on input data.
data_dependent : bool or list of bool
Whether the shape function depends on input data. If this is a list of bool,
the length of the list must be the same as the number of arguments of this op.
The list specifies per-input data dependence of the op.
shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
-> shape_tensors: List<Tensor>
@@ -374,7 +376,9 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
level : int
The priority level
"""
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
if not isinstance(data_dependent, list):
data_dependent = [data_dependent]
get(op_name).set_attr("TShapeDataDependent", data_dependent, level)
return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
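A minimal sketch (plain Python, illustrative names) of the backward-compatibility behaviour added above: a legacy bool is still accepted and is wrapped into a one-element list before being stored as TShapeDataDependent.

def normalize_dep_spec(data_dependent):
    # Mirrors the isinstance check in register_shape_func above.
    if not isinstance(data_dependent, list):
        data_dependent = [data_dependent]
    return [int(bool(flag)) for flag in data_dependent]

assert normalize_dep_spec(True) == [1]              # legacy per-op flag
assert normalize_dep_spec([False, True]) == [0, 1]  # new per-input spec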


13 changes: 8 additions & 5 deletions src/relay/analysis/util.cc
@@ -473,24 +473,27 @@ bool IsDynamic(const Type& ty) {

TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);

bool IsDataDependant(const CallNode* call) {
static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
bool IsDataDependent(const CallNode* call) {
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
Op op = Downcast<Op>(call->op);

if (!tshape_data_dependant.count(op)) {
if (!tshape_data_dependent.count(op)) {
return false;
}

if (op->name == "strided_slice") {
if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
if (attrs->begin && attrs->end && attrs->strides) {
// not data dependant if begin, end and strides exist
// not data dependent if begin, end and strides exist
return false;
}
}
}

return tshape_data_dependant[op];
for (auto req : tshape_data_dependent[op]) {
if (req->value != 0) return true;
}
return false;
}
} // namespace relay
} // namespace tvm
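In other words, with the per-input specification a call is treated as data dependent as soon as any input flag is nonzero; a short Python sketch of the predicate implemented by the loop above (illustrative only):

def is_data_dependent(dep_spec):
    # True if any per-input flag is set, e.g. [0, 1, 1, 1] for dyn.strided_slice.
    return any(flag != 0 for flag in dep_spec)

assert is_data_dependent([0, 1, 1, 1])
assert not is_data_dependent([0])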
44 changes: 27 additions & 17 deletions src/relay/backend/compile_engine.cc
@@ -435,9 +435,9 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
LOG(FATAL) << "Free variable " << var->name_hint();
return {};
} else {
ICHECK(data_dependants_.size());
bool data_dependant = data_dependants_.back();
if (data_dependant) {
ICHECK(data_dependents_per_input_.size());
auto data_dependent = data_dependents_per_input_.back();
if (data_dependent) {
param_states_[var] |= kNeedInputData;
return param_data_[var];
} else {
@@ -449,12 +449,12 @@

Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(data_dependants_.size());
bool data_dependant = data_dependants_.back();
ICHECK(data_dependents_per_input_.size());
bool data_dependent = data_dependents_per_input_.back();
if (!op->is_scalar()) {
// This is a constant weight, extract the shape of the weight tensor.
// This can not be data dependent.
CHECK(!data_dependant);
CHECK(!data_dependent);
auto ttype = op->checked_type().as<TensorTypeNode>();
int ndim = static_cast<int>(ttype->shape.size());
Array<PrimExpr> out_shape{ndim};
Expand All @@ -472,7 +472,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
scalars_.push_back(value);
return {value};
}
if (data_dependant) {
if (data_dependent) {
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
auto value = tvm::te::compute(
@@ -507,27 +507,38 @@

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
ICHECK(data_dependants_.empty() || !data_dependants_.back())
ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
<< "Error in op fusion: output of the shape func is fed to a "
<< "data-dependant shape func";
<< "data-dependent shape func";
ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
ICHECK_GT(tshape_data_dependant.count(op), 0)
<< "Internal error, cannot find TShapeDataDependant for " << op->name;
ICHECK_GT(tshape_data_dependent.count(op), 0)
<< "Internal error, cannot find TShapeDataDependent for " << op->name;

Array<Integer> dep_spec = tshape_data_dependent[op];
if (dep_spec.size() == 1) {
// This is for cases when data dependence is specified per op
// Replicate 0 or 1 flag to all arguments
for (size_t i = 1; i < call_node->args.size(); ++i) {
dep_spec.push_back(dep_spec[0]);
}
}

data_dependants_.push_back(IsDataDependant(call_node));
// Visit all inputs
Array<te::Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
for (size_t i = 0; i < call_node->args.size(); ++i) {
Expr arg = call_node->args[i];
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
for (te::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
data_dependents_per_input_.pop_back();
}
if (count_tuple) {
ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
Expand All @@ -549,7 +560,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
}
// Call shape function
auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
data_dependants_.pop_back();
readable_name_stream_ << "_" << op->name;
return outputs;
}
@@ -593,8 +603,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Stack of data dependencies for shape function, specified per each op input */
std::vector<bool> data_dependents_per_input_;
/*! \brief Scalars used in the shape function */
Array<te::Tensor> scalars_;
};
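The dependence handling added to VisitExpr_(const CallNode*) above can be summarized outside of TVM as follows (illustrative plain Python, not the C++ implementation): a one-element spec, i.e. the legacy per-op flag, is replicated to every call argument, while a full per-input spec is used as given, and each argument is then visited with its own flag on the stack.

def expand_dep_spec(dep_spec, num_args):
    if len(dep_spec) == 1:
        dep_spec = dep_spec * num_args  # replicate the single per-op flag
    assert len(dep_spec) == num_args
    return dep_spec

assert expand_dep_spec([1], 3) == [1, 1, 1]              # legacy flag broadcast
assert expand_dep_spec([0, 1, 1, 1], 4) == [0, 1, 1, 1]  # per-input spec unchanged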
2 changes: 1 addition & 1 deletion src/relay/transforms/fuse_ops.cc
@@ -241,7 +241,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
OpPatternKind op_pattern = kOpaque;
if (const OpNode* opnode = call->op.as<OpNode>()) {
auto op = GetRef<Op>(opnode);
if (IsDynamic(call->checked_type()) && IsDataDependant(call)) {
if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
// output of a shape func can't be fed to a data-dependent shape func
op_pattern = kOpaque;
} else {
6 changes: 3 additions & 3 deletions src/relay/transforms/pass_utils.h
@@ -90,11 +90,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
bool IsDynamic(const Type& ty);

/*!
* \brief Check if call is data dependant.
* \brief Check if call is data dependent.
* \param call The call to be checked.
* \return Whether the call is data dependant.
* \return Whether the call is data dependent.
*/
bool IsDataDependant(const CallNode* call);
bool IsDataDependent(const CallNode* call);

/*!
* \brief Make arbitrary transformation preserve the out most function.
4 changes: 3 additions & 1 deletion tests/cpp/relay_build_module_test.cc
@@ -105,7 +105,9 @@ TEST(Relay, BuildModule) {
}
auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
(*reg)("add", "FTVMStrategy", fgeneric, 10);
(*reg)("add", "TShapeDataDependant", false, 10);
Array<Integer> dep;
dep.push_back(0);
(*reg)("add", "TShapeDataDependent", dep, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
