[Frontend][Tensorflow]add batch_dim support for gatherV2 #7951

Merged · 5 commits · May 5, 2021
Changes from 3 commits
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -145,10 +145,14 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer batch_dims;
Integer axis;
std::string mode;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(batch_dims)
.set_default(0)
.describe("The batch_dims over which to select values.");
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
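For context on the new attribute: `batch_dims` declares how many leading dimensions of `data` and `indices` are shared batch dimensions, so each batch gathers from its own slice instead of from the whole tensor. A minimal NumPy sketch of the intended semantics (illustrative helper, not TVM code):

```python
import numpy as np

def take_with_batch_dims(a, indices, axis, batch_dims):
    # Reference semantics for take with batch_dims (assumes batch_dims >= 0).
    if batch_dims == 0:
        return np.take(a, indices, axis=axis)
    # Peel one shared batch dimension and recurse: batch i of the output
    # gathers from batch i of `a` using batch i of `indices`.
    return np.stack([
        take_with_batch_dims(a[i], indices[i], axis - 1, batch_dims - 1)
        for i in range(a.shape[0])
    ])

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
idx = np.array([[0, 2], [1, 1]])  # shape (2, 2): one index row per batch
out = take_with_batch_dims(a, idx, axis=1, batch_dims=1)
assert out.shape == (2, 2, 4)  # a.shape[:1] + idx.shape[1:] + a.shape[2:]
assert (out[1, 0] == a[1, 1]).all()
```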
120 changes: 88 additions & 32 deletions include/tvm/topi/transform.h
@@ -763,15 +763,17 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
*
* \param a The source array.
* \param indices The indices of the values to extract.
* \param batch_dims The number of batch dimensions.
* \param mode The mode for handling out-of-bound indices.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the take operation
*/
inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip",
std::string name = "T_take", std::string tag = kInjective) {
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
std::string mode = "clip", std::string name = "T_take",
std::string tag = kInjective) {
Array<PrimExpr> a_shape = a->shape;
Array<PrimExpr> out_shape = indices->shape;
PrimExpr a_size = 1;
@@ -846,6 +848,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
*
* \param a The source array.
* \param indices The indices of the values to extract.
* \param batch_dims The number of batch dimensions. Defaults to 0.
* \param axis The axis over which to select values. By default,
* the flattened input array is used.
* \param mode The mode for handling out of bound indices.
@@ -854,46 +857,99 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
*
* \return A Tensor whose op member is the take operation
*/
inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip",
std::string name = "T_take", std::string tag = kInjective) {
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
std::string mode = "clip", std::string name = "T_take",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
}
ICHECK_GE(axis, 0) << "axis out of bounds";
ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];

int indices_len = static_cast<int>(indices->shape.size());
Array<PrimExpr> out_shape;
for (size_t i = 0; i < a->shape.size(); ++i) {
if (axis == static_cast<int>(i)) {
for (size_t j = 0; j < indices->shape.size(); ++j) {
out_shape.push_back(indices->shape[j]);
}
} else {
out_shape.push_back(a->shape[i]);

int batch_dims_ = batch_dims;
if (batch_dims_ != 0) {
ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";

if (batch_dims_ < 0) {
batch_dims_ = indices->shape.size() + batch_dims_;
}

ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
for (int i = 0; i < batch_dims_; ++i) {
auto v1 = a->shape[i].as<IntImmNode>()->value;
auto v2 = indices->shape[i].as<IntImmNode>()->value;
ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
}
}

// The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
// a.shape[axis + 1:].

Array<PrimExpr> out_shape;
for (int i = 0; i < batch_dims_; ++i) {
out_shape.push_back(a->shape[i]);
}
for (int i = batch_dims_; i < axis; ++i) {
out_shape.push_back(a->shape[i]);
}
for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
out_shape.push_back(indices->shape[i]);
}
for (size_t i = axis + 1; i < a->shape.size(); ++i) {
out_shape.push_back(a->shape[i]);
}

if (mode == "clip") {
return compute(
out_shape,
[&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
},
name, tag);
if (batch_dims_ == 0) {
return compute(
out_shape,
[&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
},
name, tag);
} else {
return compute(
out_shape,
[&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
indices_position.push_back(out_index[j]);
}
for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
indices_position.push_back(out_index[j]);
}
Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
},
name, tag);
}
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bounds";
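The two `compute` branches above differ only in how an output coordinate is split back into a coordinate for `indices` and one for `a`. A hedged plain-Python restatement of the `batch_dims_ != 0` clip branch (names are illustrative):

```python
import numpy as np

def take_clip_batched(out_index, a, indices, axis, batch_dims):
    indices_ndim = indices.ndim
    # Coordinate into `indices`: the shared batch dims, then the output dims
    # that replaced `a`'s axis.
    indices_position = list(out_index[:batch_dims])
    indices_position += list(out_index[axis:axis + indices_ndim - batch_dims])
    # Coordinate into `a`: everything before `axis`, the clipped gathered
    # index, then the trailing dims.
    gathered = min(max(0, int(indices[tuple(indices_position)])), a.shape[axis] - 1)
    real_indices = list(out_index[:axis]) + [gathered]
    real_indices += list(out_index[axis + indices_ndim - batch_dims:])
    return a[tuple(real_indices)]

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
idx = np.array([[0, 2], [1, 1]])
out = np.empty((2, 2, 4), dtype=a.dtype)
for coord in np.ndindex(out.shape):
    out[coord] = take_clip_batched(coord, a, idx, axis=1, batch_dims=1)
assert (out[0, 1] == a[0, 2]).all()  # batch 0, row 1 gathers a[0, idx[0, 1]]
```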
13 changes: 9 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
@@ -2002,14 +2002,19 @@ def _impl(inputs, attr, params, mod):
axis = _get_num_param(params, inputs.pop(2))
else:
axis = 0
batch_dims = 0
if int(attr.get("batch_dims", 0)) != 0:
raise tvm.error.OpAttributeUnImplemented("Attribute batch_dims is not supported")
batch_dims = int(attr.get("batch_dims", 0))
new_input = inputs[0:2]
return AttrCvt(
op_ = AttrCvt(
op_name="take",
extras={"axis": tvm.tir.const(axis, "int32")},
ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class", "batch_dims"],
extras={
"axis": tvm.tir.const(axis, "int32"),
"batch_dims": tvm.tir.const(batch_dims, "int32"),
},
ignores=["Tindices", "Tparams", "validate_indices", "Taxis", "_class"],
)(new_input, attr)
return op_

return _impl

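With the frontend change in place, a GatherV2 node carrying a non-zero `batch_dims` attribute should now convert instead of raising `OpAttributeUnImplemented`. A hedged end-to-end sketch, assuming a TF 1.x-style graph and illustrative shapes:

```python
import tensorflow.compat.v1 as tf
from tvm.relay.frontend import from_tensorflow

tf.disable_eager_execution()
with tf.Graph().as_default() as g:
    params = tf.placeholder(tf.float32, (2, 3, 4), name="params")
    indices = tf.placeholder(tf.int32, (2, 2), name="indices")
    tf.gather(params, indices, axis=1, batch_dims=1, name="gather")
    graph_def = g.as_graph_def()

# Convert; the resulting Relay main function should call take with
# batch_dims=1 and axis=1.
mod, params_dict = from_tensorflow(
    graph_def,
    shape={"params": (2, 3, 4), "indices": (2, 2)},
    outputs=["gather"],
)
print(mod["main"])
```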
15 changes: 10 additions & 5 deletions python/tvm/relay/op/_transform.py
@@ -390,7 +390,7 @@ def _take_no_axis_shape_func(indices_shape, out_ndim):


@script
def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
def _take_with_axis_shape_func(data_shape, indices_shape, axis, batch_dims, out_ndim):
out = output_tensor((out_ndim,), "int64")
for i in const_range(axis):
out[i] = data_shape[i]
@@ -399,10 +399,10 @@ def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
for i in const_range(axis + 1, len(data_shape)):
out[i - 1] = data_shape[i]
else:
for i in const_range(len(indices_shape)):
out[axis + i] = indices_shape[i]
for i in const_range(len(indices_shape) - batch_dims):
out[axis + i] = indices_shape[i + batch_dims]
for i in const_range(axis + 1, len(data_shape)):
out[len(indices_shape) + i - 1] = data_shape[i]
out[len(indices_shape) + i - 1 - batch_dims] = data_shape[i]
return out


@@ -414,11 +414,16 @@ def take_shape_func(attrs, inputs, out_ndims):
if attrs.axis is None:
return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
axis = get_const_int(attrs.axis)
batch_dims = get_const_int(attrs.batch_dims)
data_ndim = int(inputs[0].shape[0])
if inputs[1].shape:
indices_ndim = int(inputs[1].shape[0])
if axis < 0:
axis += data_ndim
assert 0 <= axis < data_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
if batch_dims < 0:
batch_dims += indices_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), convert(batch_dims), out_ndims[0])]


@_reg.register_legalize("take")
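The shape function now skips the first `batch_dims` dimensions of `indices`, matching the rule out.shape = data.shape[:axis] + indices.shape[batch_dims:] + data.shape[axis+1:]. A plain-Python restatement for checking the arithmetic (the real version is a hybrid-script function over shape tensors; the helper name is hypothetical):

```python
def take_out_shape(data_shape, indices_shape, axis, batch_dims):
    # Batch dims are counted once: they stay in data_shape[:axis] and are
    # dropped from indices_shape.
    return (list(data_shape[:axis])
            + list(indices_shape[batch_dims:])
            + list(data_shape[axis + 1:]))

assert take_out_shape((2, 3, 4), (2, 2), axis=1, batch_dims=1) == [2, 2, 4]
assert take_out_shape((5, 6), (7,), axis=0, batch_dims=0) == [7, 6]
```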
7 changes: 5 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -388,7 +388,7 @@ def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_e
return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)


def take(data, indices, axis=None, mode="clip"):
def take(data, indices, axis=None, batch_dims=0, mode="clip"):
"""Take elements from an array along an axis.

Parameters
@@ -403,6 +403,9 @@ def take(data, indices, axis=None, mode="clip"):
The axis over which to select values. By default,
the flattened input array is used.

batch_dims : int, optional
The number of batch dimensions. Defaults to 0.

mode : str, optional
Specifies how out-of-bound indices will behave [clip, wrap, fast].
clip: clip to the range (default).
@@ -414,7 +417,7 @@
ret : relay.Expr
The computed result.
"""
return _make.take(data, indices, axis, mode)
return _make.take(data, indices, batch_dims, axis, mode)


def full(fill_value, shape=(), dtype=""):
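A hedged usage sketch of the extended Python API (the executor kind and target are arbitrary choices):

```python
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(2, 3, 4), dtype="float32")
indices = relay.var("indices", shape=(2, 2), dtype="int32")
out = relay.take(data, indices, axis=1, batch_dims=1)
mod = tvm.IRModule.from_expr(relay.Function([data, indices], out))

a = np.arange(24, dtype="float32").reshape(2, 3, 4)
idx = np.array([[0, 2], [1, 1]], dtype="int32")
res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(a, idx)
assert res.asnumpy().shape == (2, 2, 4)  # .numpy() on newer TVM
np.testing.assert_allclose(res.asnumpy()[0, 1], a[0, 2])
```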
9 changes: 6 additions & 3 deletions python/tvm/topi/transform.py
@@ -396,7 +396,7 @@ def split(ary, indices_or_sections, axis=0):
return cpp.split(ary, indices_or_sections, axis)


def take(a, indices, axis=None, mode="clip"):
def take(a, indices, axis=None, batch_dims=0, mode="clip"):
"""Take elements from an array along an axis.
Parameters
@@ -411,6 +411,9 @@ def take(a, indices, axis=None, mode="clip"):
The axis over which to select values. By default,
the flattened input array is used.
batch_dims : int, optional
The number of batch dimensions. Defaults to 0.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
@@ -422,8 +425,8 @@
ret : tvm.te.Tensor
"""
if axis is None:
return cpp.take(a, indices, mode)
return cpp.take(a, indices, int(axis), mode)
return cpp.take(a, indices, int(batch_dims), mode)
return cpp.take(a, indices, int(batch_dims), int(axis), mode)


@tvm.target.generic_func
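The same parameter is exposed through the TOPI wrapper; a hedged TE-level sketch (LLVM target assumed):

```python
import numpy as np
import tvm
from tvm import te, topi

a = te.placeholder((2, 3, 4), name="a", dtype="float32")
indices = te.placeholder((2, 2), name="indices", dtype="int32")
out = topi.take(a, indices, axis=1, batch_dims=1)

s = te.create_schedule(out.op)
f = tvm.build(s, [a, indices, out], target="llvm")

dev = tvm.cpu()
a_nd = tvm.nd.array(np.arange(24, dtype="float32").reshape(2, 3, 4), dev)
idx_nd = tvm.nd.array(np.array([[0, 2], [1, 1]], dtype="int32"), dev)
out_nd = tvm.nd.empty((2, 2, 4), "float32", dev)
f(a_nd, idx_nd, out_nd)
```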
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
@@ -107,7 +107,7 @@ Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);

Expr MakeShapeOf(Expr data, DataType dtype);

Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);

} // namespace relay
} // namespace tvm
23 changes: 17 additions & 6 deletions src/relay/op/tensor/transform.cc
@@ -1200,15 +1200,24 @@ bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size());
int axis = static_cast<int>(param->axis->value);
int batch_dims = static_cast<int>(param->batch_dims->value);
if (axis < 0) axis += ndim_data;
if (batch_dims < 0) batch_dims += ndim_indices;
ICHECK_LE(axis, ndim_data) << "axis should be within the data shape"
<< ", but got = " << axis;
ICHECK_LE(batch_dims, ndim_indices) << "batch_dims should be within the indices shape"
<< ", but got = " << batch_dims;
ICHECK_LE(batch_dims, axis) << "batch_dims should be less than or equal to axis"
<< ", but got = " << batch_dims;

oshape.reserve(ndim_data - 1 + ndim_indices);
for (int i = 0; i < axis; ++i) {
oshape.reserve(ndim_data - 1 + ndim_indices - batch_dims);
for (int i = 0; i < batch_dims; ++i) {
oshape.emplace_back(data->shape[i]);
}
for (int i = batch_dims; i < axis; ++i) {
oshape.emplace_back(data->shape[i]);
}
for (int i = 0; i < ndim_indices; ++i) {
for (int i = batch_dims; i < ndim_indices; ++i) {
oshape.emplace_back(indices->shape[i]);
}
for (int i = axis + 1; i < ndim_data; ++i) {
@@ -1224,14 +1233,16 @@ Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& input
const auto* param = attrs.as<TakeAttrs>();
ICHECK(param != nullptr);
if (!param->axis.defined()) {
return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->mode)};
return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->batch_dims, param->mode)};
} else {
return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->axis, param->mode)};
return Array<te::Tensor>{
topi::take(inputs[0], inputs[1], param->batch_dims, param->axis, param->mode)};
}
}

Expr MakeTake(Expr data, Expr indices, Integer axis, String mode) {
Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode) {
auto attrs = make_object<TakeAttrs>();
attrs->batch_dims = std::move(batch_dims);
attrs->axis = std::move(axis);
attrs->mode = std::move(mode);
static const Op& op = Op::Get("take");
15 changes: 9 additions & 6 deletions src/topi/transform.cc
@@ -87,13 +87,16 @@ TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetVal
});

TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else {
int axis = args[2];
if (args.size() == 4) {
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
int batch_dims = args[2];
*rv = take(args[0], args[1], batch_dims, mode);
} else {
ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments";
int batch_dims = args[2];
int axis = args[3];
std::string mode = args[4];
*rv = take(args[0], args[1], batch_dims, axis, mode);
}
});
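The registration now dispatches on argument count: four arguments means no axis (flattened take), five means batch_dims plus axis. A hedged sketch of the two call forms through the FFI (normally reached via the `topi.take` Python wrapper rather than called directly):

```python
import tvm
from tvm import te

topi_take = tvm.get_global_func("topi.take")
a = te.placeholder((5, 6), name="a", dtype="float32")
idx = te.placeholder((3,), name="idx", dtype="int32")

flat = topi_take(a, idx, 0, "clip")       # 4 args: batch_dims=0, no axis
axised = topi_take(a, idx, 0, 1, "clip")  # 5 args: batch_dims=0, axis=1
```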
