Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX][TOPI][strided_slice] Fix topi.strided_slice output shape #17502

68 changes: 64 additions & 4 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
#include <unordered_set>
#include <vector>

#include "tvm/ir/expr.h"
#include "tvm/runtime/data_type.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/op.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace topi {
Expand Down Expand Up @@ -635,6 +639,55 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
return result;
}

inline PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
  // Fast path: if the index and the extent are the same symbolic variable
  // (matched here by name hint), clamping would be a no-op, so return as-is.
  const auto* index_node = index.as<tvm::tir::VarNode>();
  const auto* extent_node = extent.as<tvm::tir::VarNode>();
  if (index_node != nullptr && extent_node != nullptr &&
      index_node->name_hint == extent_node->name_hint) {
    return index;
  }

  // Python-style negative indices count from the end of the axis. Skip the
  // adjustment when the index is a constant that is provably non-negative.
  if (!(index->IsInstance<tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
    index = tvm::if_then_else(index < 0, index + extent, index);
  }

  // The valid range depends on traversal direction: [-1, extent - 1] for a
  // negative stride (may stop one before the first element), [0, extent]
  // otherwise (may stop one past the last element).
  PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
  PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
  return tvm::min(tvm::max(index, lower_bound), upper_bound);
}

// Canonicalize a fully-static slice index: wrap Python-style negative indices
// around the extent, then clamp to the valid range for the given traversal
// direction (a negative stride may legitimately stop at -1, a positive stride
// at `extent`).
inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  const int64_t wrapped = index < 0 ? index + extent : index;
  const bool reverse = stride < 0;
  const int64_t lower_bound = reverse ? int64_t{-1} : int64_t{0};
  const int64_t upper_bound = reverse ? extent - 1 : extent;
  return std::min(std::max(wrapped, lower_bound), upper_bound);
}

// Canonicalize a slice index, folding the computation to a 64-bit constant
// when every operand is a compile-time integer so downstream shape arithmetic
// stays constant; otherwise emit the equivalent symbolic expression.
inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
  const bool all_static = index->IsInstance<tvm::IntImmNode>() &&
                          extent->IsInstance<tvm::IntImmNode>() &&
                          stride->IsInstance<tvm::IntImmNode>();
  if (!all_static) {
    return DynamicCanonicalizeIndex(index, extent, stride);
  }
  int64_t folded =
      StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride));
  return tvm::IntImm(tvm::DataType::Int(64), folded);
}

// Compute the output length of one sliced axis. With `assume_inbound` the
// length is a plain ceiling division; otherwise both endpoints are clamped
// to the axis first and the span is measured in the direction of travel
// (a negative stride walks from begin down to end).
inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent,
                          bool assume_inbound = true) {
  if (assume_inbound) {
    return ceildiv(end - begin, stride);
  }
  PrimExpr lo = CanonicalizeIndex(begin, extent, stride);
  PrimExpr hi = CanonicalizeIndex(end, extent, stride);
  return tvm::if_then_else(stride < 0, ceildiv(lo - hi, -stride), ceildiv(hi - lo, stride));
}

/*!
* \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
*
Expand All @@ -644,14 +697,15 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param axes Specifies which axes will be updated.
* \param assume_inbound Specifies if all indices are assumed to be inbound
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the dynamic_strided_slice operation
*/
inline Tensor dynamic_strided_slice_with_axes(
const Tensor& x, const Array<PrimExpr>& begin, const Array<PrimExpr>& end,
const Array<PrimExpr>& strides, const Array<Integer>& axes,
const Array<PrimExpr>& strides, const Array<Integer>& axes, bool assume_inbound = true,
std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) {
const size_t src_tensor_dim = x->shape.size();
ICHECK_EQ(begin.size(), end.size());
Expand All @@ -669,7 +723,8 @@ inline Tensor dynamic_strided_slice_with_axes(
Array<PrimExpr> out_shape = x->shape;
for (size_t i = 0; i < begin.size(); i++) {
int axis = axes[i]->value;
PrimExpr new_shape = analyzer.Simplify(ceildiv(end[i] - begin[i], strides[i]));
PrimExpr new_shape =
analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound));
out_shape.Set(axis, new_shape);
}

Expand Down Expand Up @@ -697,13 +752,15 @@ inline Tensor dynamic_strided_slice_with_axes(
* \param end Indices indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param assume_inbound Specifies if all indices are assumed to be inbound
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the dynamic_strided_slice operation
*/
inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
bool assume_inbound = true,
std::string name = "T_dynamic_strided_slice",
std::string tag = kInjective) {
const size_t src_tensor_dim = x->shape.size();
Expand All @@ -721,7 +778,8 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begi
// Check ProducerLoad to keep backward compatibility for Relay.
if (!begin[i]->IsInstance<ProducerLoadNode>() && !end[i]->IsInstance<ProducerLoadNode>() &&
!strides[i]->IsInstance<ProducerLoadNode>()) {
out_shape.push_back(analyzer.Simplify(ceildiv(end[i] - begin[i], strides[i])));
out_shape.push_back(
analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound)));
} else {
out_shape.push_back(tvm::tir::Var("dim"));
}
Expand Down Expand Up @@ -755,13 +813,15 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begi
* \param end Indices indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param assume_inbound Specifies if all indices are assumed to be inbound
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the dynamic_strided_slice operation
*/
inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
const te::Tensor& end, const te::Tensor& strides,
bool assume_inbound = true,
std::string name = "T_strided_slice_dynamic",
std::string tag = topi::kInjective) {
DataType index_dtype = begin->shape[0]->dtype;
Expand All @@ -776,7 +836,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
end_expr.push_back(end(ind));
strides_expr.push_back(strides(ind));
}
return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, assume_inbound, name, tag);
}

/*!
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ def _redistribute_replica_to_shard(_bb: BlockBuilder, call: Call) -> Expr:
axes=[axis],
begin=[worker_id_symbol * split_axis_size // num_workers],
end=[(worker_id_symbol + 1) * split_axis_size // num_workers],
assume_inbound=True,
)
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _relax_tuple_to_tir(relax_tuple):
strides,
axes,
slice_mode="end",
assume_inbound=call.attrs.assume_inbound,
)


Expand Down
7 changes: 5 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0):
return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)


def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end", assume_inbound=True):
"""Slice of an array.

Parameters
Expand Down Expand Up @@ -200,6 +200,9 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
the sizeof a slice starting at the location specified by begin. If end[i]
is -1, all remaining elements in that dimension are included in the slice.

assume_inbound: bool, optional
A flag to indicate if all indices are assumed to be inbound

Returns
-------
ret : tvm.te.Tensor
Expand All @@ -223,7 +226,7 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
strides = []
if axes is None:
axes = []
return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)
return cpp.strided_slice(a, begin, end, strides, axes, slice_mode, assume_inbound)


def dynamic_strided_slice(a, begin, end, strides, output_shape):
Expand Down
26 changes: 2 additions & 24 deletions src/relax/op/tensor/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "index.h"

#include <tvm/relax/analysis.h>
#include <tvm/topi/transform.h>

#include <algorithm>
#include <optional>
Expand Down Expand Up @@ -171,29 +172,6 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional<Expr> strid

TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);

inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
// Handle Python-style negative indices
index = if_then_else(index < 0, index + extent, index);
// Clamp the result to valid indices
PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
index = tvm::min(tvm::max(index, lower_bound), upper_bound);

return index;
}

PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent,
bool assume_inbound) {
if (assume_inbound) {
return ceildiv(end - begin, stride);
} else {
begin = CanonicalizeIndex(begin, extent, stride);
end = CanonicalizeIndex(end, extent, stride);
return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
ceildiv(end - begin, stride));
}
}

/* \brief Helper function to unpack a relax::Tuple
*
* A `relax::Tuple` may be provided to an operator as an in-line
Expand Down Expand Up @@ -424,7 +402,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx
PrimExpr end = end_tuple[i];

PrimExpr output_dim =
GetLength(begin, end, strides_tuple[i], input_dim, attrs->assume_inbound);
topi::GetLength(begin, end, strides_tuple[i], input_dim, attrs->assume_inbound);

arith::Analyzer* analyzer = ctx->GetAnalyzer();
std::optional<With<arith::ConstraintContext>> context;
Expand Down
9 changes: 7 additions & 2 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include <tvm/topi/transform.h>
#include <tvm/topi/utils.h>

#include <iostream>

#include "tvm/ir/expr.h"

namespace tvm {
namespace topi {

Expand Down Expand Up @@ -179,6 +183,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
Array<Integer> axes = args[4];
bool assume_inbound = args[6];
if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) &&
IsConstIntArray(x->shape)) {
Array<Integer> begin_static = args[1];
Expand All @@ -192,9 +197,9 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
}
} else {
if (axes.size()) {
*rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes);
*rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound);
} else {
*rv = dynamic_strided_slice(x, begin, end, strides);
*rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound);
}
}
});
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_op_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm import TVMError
from tvm.ir import Op, VDevice
from tvm.script import ir as I, relax as R, tir as T
import numpy as np


def test_op_correctness():
Expand Down Expand Up @@ -1010,5 +1011,44 @@ def strided_slice(
tvm.ir.assert_structural_equal(expected, after)


def test_legalize_dynamic_begin_inf_end():
"""relax.op.strided_slice FLegalize must support dynamic begin/end"""

@I.ir_module
class before:
Reviewer comment (Member): You can use `# fmt: off` and `# fmt: on` to disable black formatting for specific code (e.g. TVMScript).

@R.function
def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1, 16)):
index = T.int64()
return R.strided_slice(
A, [0], [index], [T.int64(np.iinfo(np.int64).max)], assume_inbound=False
)

# fmt: off
@I.ir_module
class expected:
@T.prim_func(private=True)
def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64):
T.func_attr({"tir.noalias": T.bool(True)})
T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)))
# with T.block("root"):
for ax0, ax1 in T.grid(T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)):
with T.block("T_dynamic_strided_slice_with_axes"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0 + index, v_ax1])
T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1])
T_dynamic_strided_slice_with_axes[v_ax0, v_ax1] = A[v_ax0 + index, v_ax1]

@R.function
def main(A: R.Tensor((16, 16), dtype="float32"), B: R.Shape(["index"])) -> R.Tensor(("T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0)", 16), dtype="float32"):
index = T.int64()
cls = expected
gv = R.call_tir(cls.strided_slice, (A,), out_sinfo=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index]))
return gv
# fmt: on

after = tvm.relax.transform.LegalizeOps()(before)
tvm.ir.assert_structural_equal(expected, after)


if __name__ == "__main__":
tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class StridedSlice:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), "float32"):
n = T.int64()
gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3])
gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3], assume_inbound=True)
return gv

@I.ir_module
Expand Down
Loading