[Topi] Fix GPU Dynamic Topk by Improving Dynamic Strided Slice in Topi (apache#7018)

* Fix GPU dynamic Topk

* Fix style

* Minor fix

* Simplify dynamic checking

* Fix lint

* More improvements

* Disable test any topk
kevinthesun authored and electriclilies committed Feb 18, 2021
1 parent effd4e3 commit 3cbb44d
Showing 7 changed files with 80 additions and 33 deletions.
15 changes: 15 additions & 0 deletions include/tvm/topi/detail/constant_utils.h
@@ -47,6 +47,21 @@ using namespace tvm::te;
*/
inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); }

+/*!
+ * \brief Test whether the given Array has every element as constant integer
+ *
+ * \param array the array to query
+ *
+ * \return true if every element in array is constant int or uint, false otherwise.
+ */
+inline bool IsConstIntArray(Array<PrimExpr> array) {
+  bool is_const_int = true;
+  for (auto const& elem : array) {
+    is_const_int &= elem->IsInstance<tvm::tir::IntImmNode>();
+  }
+  return is_const_int;
+}

/*!
* \brief Get the value of the given constant integer expression. An error
* is logged if the given expression is not a constant integer.
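For intuition, here is a Python analogue of the new IsConstIntArray helper (the helper itself is C++; this sketch only illustrates the predicate it computes, and the variable names are illustrative):

import tvm

def is_const_int_array(arr):
    # True only when every element is a compile-time integer constant,
    # mirroring IsConstIntArray above.
    return all(isinstance(e, tvm.tir.IntImm) for e in arr)

static_shape = [tvm.tir.IntImm("int32", 4), tvm.tir.IntImm("int32", 8)]
dynamic_shape = [tvm.tir.IntImm("int32", 4), tvm.te.size_var("n")]
print(is_const_int_array(static_shape))   # True
print(is_const_int_array(dynamic_shape))  # False: size_var is symbolic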
2 changes: 1 addition & 1 deletion include/tvm/topi/nn.h
@@ -614,7 +614,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
  out = reshape(out, r_p_shape);

  // Crop the start and end of dimensions of out
-  Array<Integer> begin_idx, end_idx, strides;
+  Array<PrimExpr> begin_idx, end_idx, strides;
  for (size_t i = 0; i < r_p_shape.size(); ++i) {
    strides.push_back(Integer(1));
    if (i > 0 && i <= num_block_dims) {
43 changes: 34 additions & 9 deletions include/tvm/topi/transform.h
@@ -598,17 +598,42 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
*
* \return A Tensor whose op member is the split operation
*/
-inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
-                            const Array<Integer>& strides, std::string slice_mode = "end",
-                            std::string name = "T_strided_slice", std::string tag = kInjective) {
+inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
+                            const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
+                            std::string slice_mode = "end", std::string name = "T_strided_slice",
+                            std::string tag = kInjective) {
  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
+  // Quick path for dynamic shape strided slice.
+  // This eases the use of dynamic strided slice in topi.
+  bool is_static = IsConstIntArray(x->shape);
+  is_static &= IsConstIntArray(begin);
+  is_static &= IsConstIntArray(end);
+  is_static &= IsConstIntArray(strides);
+
+  Array<PrimExpr> out_shape;
+  if (!is_static) {
+    for (size_t i = 0; i < src_tensor_dim; ++i) {
+      out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
+    }
+    return te::compute(
+        out_shape,
+        [&](const Array<tvm::tir::Var>& indices) {
+          Array<PrimExpr> real_indices;
+          for (size_t i = 0; i < src_tensor_dim; ++i) {
+            real_indices.push_back(indices[i] * strides[i] + begin[i]);
+          }
+          return x(real_indices);
+        },
+        name, tag);
+  }

  // Setup the ranges.
  // NOTE: this code duplicates the shape inference logic in relay.op;
  // consider refactoring in the future.
  std::vector<int64_t> stride_vec(src_tensor_dim, 1);
  for (size_t i = 0; i < strides.size(); ++i) {
    ICHECK(strides[i].defined());
-    stride_vec[i] = strides[i]->value;
+    stride_vec[i] = GetConstInt(strides[i]);
  }

  const int64_t max_range = std::numeric_limits<int64_t>::max();
@@ -619,7 +644,7 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
      // value=None
      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
    } else {
-      begin_vec.push_back(begin[i]->value);
+      begin_vec.push_back(GetConstInt(begin[i]));
    }
  }
  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
@@ -633,20 +658,20 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
    if (!end[i].defined()) {
      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
    } else if (slice_mode == "size") {
-      if (end[i]->value < 0) {
+      int64_t end_val = GetConstInt(end[i]);
+      if (end_val < 0) {
        end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
      } else {
-        end_vec.push_back(begin_vec[i] + end[i]->value);
+        end_vec.push_back(begin_vec[i] + end_val);
      }
    } else {
-      end_vec.push_back(end[i]->value);
+      end_vec.push_back(GetConstInt(end[i]));
    }
  }
  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
  }
  // Compute
-  Array<PrimExpr> out_shape;
  Array<PrimExpr> begin_expr;
  Array<PrimExpr> strides_expr;
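And a minimal te-level sketch of what the new dynamic quick path produces, assuming a 2-D input with symbolic extents (shapes and names are illustrative):

import tvm
from tvm import te, tir

n = te.size_var("n")
k = te.size_var("k")
x = te.placeholder((n, 100), name="x", dtype="float32")
begin, end, strides = [0, 0], [n, k], [1, 1]

# Output extent per axis is indexdiv(end - begin, stride); each output
# index maps back to begin + index * stride in the input.
out = te.compute(
    tuple(tir.indexdiv(end[i] - begin[i], strides[i]) for i in range(2)),
    lambda i, j: x[i * strides[0] + begin[0], j * strides[1] + begin[1]],
    name="T_strided_slice_dynamic",
)
print(out.shape)  # symbolic extents, e.g. floordiv(n, 1) and floordiv(k, 1)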
20 changes: 11 additions & 9 deletions python/tvm/topi/cuda/sort.py
@@ -479,27 +479,28 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
name="topk_gpu",
tag="topk_gpu",
)
if k < 1:
if isinstance(k, int) and k < 1:
if ret_type == "indices":
return output[1]
return output
beg = [0] * ndim
end = []
strides = [1] * ndim
for i in range(ndim):
if i == axis:
end.append(k)
end.append(k if isinstance(k, int) else tvm.te.size_var("dim"))
else:
end.append(data.shape[i])
if ret_type == "both":
values_out, indices_out = output
values_out = strided_slice(values_out, beg, end)
indices_out = strided_slice(indices_out, beg, end)
values_out = strided_slice(values_out, beg, end, strides)
indices_out = strided_slice(indices_out, beg, end, strides)
output = [values_out, indices_out]
elif ret_type == "values":
output = [strided_slice(output, beg, end)]
output = [strided_slice(output, beg, end, strides)]
else: # ret_type == "indices"
indices_out = output[1]
output = [strided_slice(indices_out, beg, end)]
output = [strided_slice(indices_out, beg, end, strides)]
return output
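The key change above is that end may now contain a symbolic extent instead of a Python int. A small standalone sketch of the bound construction (helper name and shapes are illustrative):

from tvm import te

def make_slice_bounds(data_shape, k, axis, ndim):
    # When k is a Python int, slice to exactly k elements along axis;
    # otherwise use a symbolic extent that is resolved at runtime.
    beg = [0] * ndim
    strides = [1] * ndim
    end = [
        (k if isinstance(k, int) else te.size_var("dim")) if i == axis else data_shape[i]
        for i in range(ndim)
    ]
    return beg, end, strides

print(make_slice_bounds([20, 100], 5, 1, 2))                 # static k
print(make_slice_bounds([20, 100], te.size_var("k"), 1, 2))  # dynamic k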


@@ -561,10 +562,11 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
tag="topk_gpu",
)

if k > 0:
if not isinstance(k, int) or k > 0:
beg = [0] * ndim
end = data.shape[:-1] + [k]
out = [strided_slice(o, beg, end) for o in out]
end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
strides = [1] * ndim
out = [strided_slice(o, beg, end, strides) for o in out]

if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
19 changes: 14 additions & 5 deletions src/relay/op/tensor/transform.cc
@@ -2380,6 +2380,7 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
  ICHECK(param != nullptr);
  Array<Integer> begin, end, strides;
+  Array<PrimExpr> begin_expr, end_expr, strides_expr;
  begin = param->begin.value();
  end = param->end.value();
  strides = param->strides.value();
@@ -2392,8 +2393,6 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
    for (size_t i = 0; i < src_tensor_dim; ++i) {
      out_shape.push_back(tvm::tir::Var("dim"));
    }
-    Array<PrimExpr> begin_expr;
-    Array<PrimExpr> strides_expr;
    for (size_t i = 0; i < src_tensor_dim; ++i) {
      int64_t begin_i = begin[i]->value;
      if (begin_i < 0) {
@@ -2414,8 +2413,19 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
          return input(real_indices);
        },
        std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})};
-  }
+  } else {
+    for (size_t i = 0; i < begin.size(); ++i) {
+      begin_expr.push_back(begin[i]);
+    }
+    for (size_t i = 0; i < end.size(); ++i) {
+      end_expr.push_back(end[i]);
+    }
+    for (size_t i = 0; i < strides.size(); ++i) {
+      strides_expr.push_back(strides[i]);
+    }
+  }
-  return Array<te::Tensor>{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
+  return Array<te::Tensor>{
+      topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)};
}
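As a quick sanity sketch of the static path after the signature change (using the same executor API as the tests below; shapes are illustrative, not from the patch):

import numpy as np
import tvm
import tvm.testing
from tvm import relay

x = relay.var("x", shape=(4, 8), dtype="float32")
y = relay.strided_slice(x, begin=[0, 2], end=[4, 6], strides=[1, 1])
mod = tvm.IRModule.from_expr(relay.Function([x], y))

data = np.random.rand(4, 8).astype("float32")
ex = relay.create_executor("graph", mod=mod, ctx=tvm.cpu(), target="llvm")
res = ex.evaluate()(data)
tvm.testing.assert_allclose(res.asnumpy(), data[0:4, 2:6])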

// Positional relay function to create StridedSlice operator used by frontend FFI.
@@ -2731,8 +2741,7 @@ Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>&
        << topi::GetConstInt(src_shape[axis]);
    }
  }
-  return Array<te::Tensor>{topi::strided_slice(inputs[0], GetIntArray(begin_idx),
-                                               GetIntArray(end_idx), GetIntArray(strides), "end")};
+  return Array<te::Tensor>{topi::strided_slice(inputs[0], begin_idx, end_idx, strides, "end")};
}

TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike);
4 changes: 2 additions & 2 deletions tests/python/relay/dyn/test_dynamic_op_level6.py
@@ -22,8 +22,8 @@
from tvm import relay
import tvm.testing

-# TODO(mbrookhart): Enable when we can get it working
-# @tvm.testing.uses_gpu
+
+@tvm.testing.uses_gpu
def test_dynamic_topk():
    def verify_topk(k, axis, ret_type, is_ascend, dtype):
        shape = (20, 100)
10 changes: 3 additions & 7 deletions tests/python/relay/test_any.py
@@ -815,15 +815,11 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):
    else:
        ref_out = sorted[0:kval]

-    for kind in ["debug", "vm"]:
-        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
-        result = ex.evaluate()(*in_vals)
-        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
-
-    # TODO(@zhiics) Fix topk cuda schedule for dynamic inputs
-    # check_result(in_vals, mod, ref_out)
+    check_result(in_vals, mod, ref_out)


-@tvm.testing.uses_gpu
+# TODO(kevinthesun): enable this test when Thrust is available in ci.
+# @tvm.testing.uses_gpu
def test_any_topk():
    verify_any_topk(any_dims(1), 5, (10,), "float32")
    verify_any_topk(any_dims(2), 2, (6, 3), "int32")
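For reference, a hedged sketch of the graph verify_any_topk builds for its first case (any_dims and check_result are helpers defined elsewhere in this test file; their exact bodies are assumed):

import tvm
from tvm import relay

# any_dims(1) stands for a 1-D tensor whose extent is relay.Any().
data = relay.var("data", shape=(relay.Any(),), dtype="float32")
out = relay.topk(data, k=5, ret_type="values")
mod = tvm.IRModule.from_expr(relay.Function([data], out))
# check_result then runs mod (e.g. on the VM) and compares against the
# NumPy reference computed in verify_any_topk.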
