Skip to content

Commit

Permalink
refactor compute
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 80442f8 commit 9a79560
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 41 deletions.
70 changes: 30 additions & 40 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,25 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
return result;
}

// inline te::Tensor strided_slice_compute_common() {}
inline te::Tensor strided_slice_compute_common(const te::Tensor& x,
const Array<PrimExpr>& out_shape,
const Array<PrimExpr>& begin,
const Array<PrimExpr>& strides,
const Array<Integer>& axes, const std::string& name,
const std::string& tag) {
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides[i] + begin[i];
real_indices.Set(axes[i], ind);
}
return x(real_indices);
},
name, tag);
}

/*!
* \brief strided_slice of a tensor with dynamic begin/end/stride
Expand Down Expand Up @@ -597,63 +615,48 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b

inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice",
std::string name = "T_dynamic_strided_slice",
std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
ICHECK_EQ(begin.size(), src_tensor_dim);
ICHECK_EQ(end.size(), src_tensor_dim);
ICHECK_EQ(strides.size(), src_tensor_dim);

Array<PrimExpr> out_shape;
Array<Integer> axes;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
axes.push_back(i);
}
return te::compute(
out_shape,
[&](const Array<tvm::tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides[i] + begin[i]);
}
return x(real_indices);
},
name, tag);
return strided_slice_compute_common(x, out_shape, begin, strides, axes, name, tag);
}

inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array<Integer>& begin,
inline Tensor strided_slice_dynamic_input(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = input->shape.size();
size_t src_tensor_dim = x->shape.size();
ICHECK(begin.size() == src_tensor_dim)
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
Array<PrimExpr> begin_expr, end_expr, strides_expr;
Array<Integer> axes;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_i = begin[i]->value;
if (begin_i < 0) {
begin_i += topi::detail::GetConstInt(input->shape[i]);
begin_i += topi::detail::GetConstInt(x->shape[i]);
}
begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
(i < strides.size() ? strides[i]->value : 1)));
axes.push_back(i);
}
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
return input(real_indices);
},
std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective});
return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag);
}

inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
Expand Down Expand Up @@ -689,7 +692,6 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
std::vector<int64_t> end_vec;
for (size_t i = 0; i < end.size(); ++i) {
// allow end to be None

if (!end[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else if (slice_mode == "size") {
Expand Down Expand Up @@ -740,19 +742,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i]));
out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
}

return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return x(real_indices);
},
name, tag);
return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag);
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
*rv = strided_slice_dynamic_input(x, begin_static, end_static, strides_static, slice_mode);
}
} else {
*rv = dynamic_strided_slice(x, begin, end, strides, slice_mode);
*rv = dynamic_strided_slice(x, begin, end, strides);
}
});

Expand Down

0 comments on commit 9a79560

Please sign in to comment.