diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 5c3734b6e4c0..eb1ad8bd622b 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -550,6 +550,8 @@ inline Array split(const Tensor& x, Array split_indices, int a return result; } +// inline te::Tensor strided_slice_compute_common() {} + /*! * \brief strided_slice of a tensor with dynamic begin/end/stride * @@ -657,7 +659,7 @@ inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array& begin, const Array& end, const Array& strides, const Array& axes, std::string slice_mode = "end", - std::string name = "T_strided_slice_dynamic_input", + std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { size_t src_tensor_dim = x->shape.size(); @@ -703,10 +705,12 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg } // Compute - Array begin_expr; - Array strides_expr; - Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(x->shape[i]); + } + + Array begin_expr, strides_expr; for (size_t i = 0; i < axes.size(); ++i) { int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; ICHECK(x->shape[axes[i]]->IsInstance()) @@ -734,7 +738,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); strides_expr.push_back( make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); - out_shape.push_back(slice_size); + out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); } return te::compute( @@ -743,12 +747,12 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg Array real_indices; for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < axes.size(); ++i) { - PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i]; + PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i]; real_indices.Set(axes[i], ind); } return x(real_indices); }, - std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective}); + name, tag); } /*! diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6671abb2e263..40958116517f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1359,7 +1359,7 @@ def has_static_axes(): else: strides_np = steps.data.asnumpy().astype("int64") - # return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)) + return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)) # Update the starts and ends according to axes if required. if axes is not None: