diff --git a/oneflow/api/python/framework/tensor_functions.cpp b/oneflow/api/python/framework/tensor_functions.cpp index 268c877dee8..2dbfd4a3a02 100644 --- a/oneflow/api/python/framework/tensor_functions.cpp +++ b/oneflow/api/python/framework/tensor_functions.cpp @@ -22,6 +22,7 @@ limitations under the License. #include "oneflow/core/common/shape_vec.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/common/shape.h" +#include "oneflow/core/common/wrap_dim_utils.h" namespace oneflow { namespace one { @@ -329,10 +330,7 @@ static PyObject* PyTensorObject_size(PyObject* self, PyObject* args, PyObject* k if (idx_obj == NULL || idx_obj == Py_None) return TensorSize_NewFromShape(*shape); int64_t idx = PyLong_AsLongLong(idx_obj); int64_t ndim = shape->NumAxes(); - - CHECK_OR_THROW(idx >= -ndim && idx < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << idx << ")"; + idx = CHECK_JUST(maybe_wrap_dim(idx, ndim)); idx = idx < 0 ? idx + ndim : idx; return PyLong_FromLongLong(shape->At(idx)); END_HANDLE_ERRORS diff --git a/oneflow/core/common/wrap_dim_utils.h b/oneflow/core/common/wrap_dim_utils.h new file mode 100644 index 00000000000..929b203cf45 --- /dev/null +++ b/oneflow/core/common/wrap_dim_utils.h @@ -0,0 +1,40 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +// align with pytorch: `c10/core/WrapDimMinimal.h` +static inline Maybe maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, + bool wrap_scalar = true) { + if (dim_post_expr <= 0) { + if (!wrap_scalar) { + return Error::RuntimeError() + << "dimension specified as " << dim << " but tensor has no dimensions"; + } + dim_post_expr = 1; // this will make range [-1, 0] + } + + int64_t min = -dim_post_expr; + int64_t max = dim_post_expr - 1; + if (dim < min || dim > max) { + return Error::IndexError() << "Dimension out of range (expected to be in range of [" << min + << ", " << max << "], but got " << dim << ")"; + } + if (dim < 0) dim += dim_post_expr; + return dim; +} +} // namespace oneflow diff --git a/oneflow/core/framework/tensor_methods.cpp b/oneflow/core/framework/tensor_methods.cpp index b4479136f93..7ed41a652f8 100644 --- a/oneflow/core/framework/tensor_methods.cpp +++ b/oneflow/core/framework/tensor_methods.cpp @@ -24,6 +24,7 @@ limitations under the License. #include "oneflow/core/register/ofblob.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/ep/include/device_manager_registry.h" +#include "oneflow/core/common/wrap_dim_utils.h" namespace oneflow { namespace one { @@ -400,12 +401,7 @@ Maybe Transpose(const std::shared_ptr& input, const std::vector< CHECK_EQ_OR_RETURN(permute.size(), ndim) << "permute size should be equal to input tensor's ndim, but got " << permute.size(); auto positive_perm = permute; - for (auto i = 0; i < positive_perm.size(); i++) { - if (positive_perm[i] < 0) { positive_perm[i] += ndim; } - CHECK_OR_RETURN(positive_perm[i] >= 0 && positive_perm[i] < ndim) - << "IndexError: Dimension out of range (expected to be in range of [" << -ndim << "," - << ndim << " ) but got " << positive_perm[i]; - } + for (auto i = 0; i < positive_perm.size(); i++) { JUST(maybe_wrap_dim(positive_perm[i], ndim)); } DimVector target_dims(ndim); DimVector stride_vec(ndim); diff --git a/oneflow/core/functional/function_library.h b/oneflow/core/functional/function_library.h index 13d59a0434b..570edffb3bc 100644 --- a/oneflow/core/functional/function_library.h +++ b/oneflow/core/functional/function_library.h @@ -17,6 +17,7 @@ limitations under the License. #define ONEFLOW_CORE_FUNCTIONAL_FUNCTION_LIBRARY_H_ #include "oneflow/core/common/util.h" +#include "oneflow/core/common/wrap_dim_utils.h" #include "oneflow/core/functional/packed_functor.h" #include "oneflow/core/common/stride.h" #include "oneflow/core/framework/tensor_methods.h" diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp index 5f000756846..3a604b3ebf1 100644 --- a/oneflow/core/functional/impl/activation_functor.cpp +++ b/oneflow/core/functional/impl/activation_functor.cpp @@ -226,9 +226,7 @@ class GluFunctor { const auto ndim = input->ndim(); CHECK_GT_OR_RETURN(ndim, 0) << Error::RuntimeError() << "glu does not support scalars because halving size must be even"; - CHECK_OR_RETURN(dim >= -ndim && dim < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dim << ")"; + dim = JUST(maybe_wrap_dim(dim, ndim)); if (dim < 0) { dim += ndim; } int64_t nc = input->dim(dim); CHECK_EQ_OR_RETURN(nc % 2, 0) << Error::RuntimeError() @@ -332,10 +330,7 @@ class SoftmaxFunctorBase { int64_t dim_ = dim ? JUST(dim) : get_dim(); if (dim_ < 0) { dim_ += num_axes; } - CHECK_OR_RETURN(dim_ >= -num_axes && dim_ < num_axes) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -num_axes << ", " << num_axes - 1 << "], but got " << dim_ << ")"; - + dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); if (dim_ != num_axes - 1) { std::vector input_perm(input_shape->dim_vec().size(), 0); for (size_t i = 1; i < input_perm.size(); ++i) { input_perm[i] = i; } diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 9badf700e79..c00d2ac8d20 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -68,9 +68,7 @@ class ArgMaxFunctor { int new_dim = JUST(dim); const int32_t ndims = input->shape()->NumAxes(); - CHECK_OR_RETURN(new_dim >= -ndims && new_dim < ndims) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndims - << "," << ndims << " ] but got " << new_dim << ")"; + new_dim = JUST(maybe_wrap_dim(new_dim, ndims)); if (new_dim < 0) { new_dim += ndims; } const auto do_cast = [&](const std::shared_ptr& x) -> Maybe { return Cast(x, JUST(dtype), /*pin_memory=*/false); @@ -469,10 +467,7 @@ class ConcatFunctor { int64_t ndim = inputs[0]->ndim(); int64_t max_dim_size = 0; CHECK_GE_OR_RETURN(ninput, 1) << Error::RuntimeError() << "inputs size must greater than 0"; - CHECK_OR_RETURN((-(ndim) <= dim) && (dim <= (ndim - 1))) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dim << ")"; - if (dim < 0) { axis += ndim; } + axis = JUST(maybe_wrap_dim(axis, ndim)); const std::shared_ptr& shape = inputs[0]->shape(); for (const auto& input : inputs) { @@ -526,10 +521,7 @@ class StackFunctor { const int64_t ninput = inputs.size(); int64_t ndims = inputs[0]->ndim(); int64_t stack_dim = dim; - if (dim < 0) { stack_dim = stack_dim + ndims + 1; } - CHECK_OR_RETURN(stack_dim >= 0 && stack_dim <= ndims) - << Error::IndexError() << "Dimension out of range (expected in range of [" << -ndims - 1 - << ", " << ndims << "], but got " << stack_dim << ")"; + stack_dim = JUST(maybe_wrap_dim(stack_dim, ndims + 1)); if (ninput == 1) { return ExpandDims(inputs[0], dim); } const std::shared_ptr& first_in_shape = inputs[0]->shape(); for (const auto& input : inputs) { @@ -666,9 +658,7 @@ class ExpandDimsFunctor { Maybe operator()(const std::shared_ptr& input, const int32_t& dim) const { int32_t expand_dim = dim; const int32_t ndim = input->shape()->NumAxes(); - CHECK_OR_RETURN(-(ndim + 1) <= dim && dim <= ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -(ndim + 1) << ", " << ndim << "], but got " << dim << ")"; + JUST(maybe_wrap_dim(dim, ndim + 1)); if (dim < 0) { expand_dim = dim + ndim + 1; } MutableAttrMap attrs; JUST(attrs.SetAttr("axis", expand_dim)); @@ -695,10 +685,7 @@ class SqueezeFunctor { if (dim.has_value()) { std::vector dims = *JUST(dim); for (int32_t dim_i : dims) { - CHECK_OR_RETURN((dim_i >= -ndim) && (dim_i <= ndim - 1)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -ndim << "," << ndim - 1 << "], but got " << dim_i << ")"; - if (dim_i < 0) { dim_i += ndim; } + dim_i = JUST(maybe_wrap_dim(dim_i, ndim)); if (x->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); } } } else { @@ -776,16 +763,8 @@ class DimGatherFunctor { << Error::RuntimeError() << "gather(): Expected dtype int32 or int64 for index"; CHECK_EQ_OR_RETURN(sparse_grad, false) << Error::RuntimeError() << "Only support bool = False for now!"; - if (index->ndim() > 0) { - CHECK_LT_OR_RETURN(dim, index->ndim()) - << Error::RuntimeError() << "Dimension out of range (expected to be in range of [" - << -index->ndim() << ", " << index->ndim() - 1 << "], but got " << dim << ")"; - } else { - // For 0-dim Tensor - CHECK_LE_OR_RETURN(dim, index->ndim()) - << Error::RuntimeError() - << "Dimension out of range (expected to be in range of [-1, 0], but got " << dim << ")"; - } + + JUST(maybe_wrap_dim(dim, index->ndim())); if (input->ndim() > 0 && index->ndim() > 0) { CHECK_EQ_OR_RETURN(input->ndim(), index->ndim()) << Error::RuntimeError() @@ -1300,11 +1279,8 @@ class NarrowFunctor { const int64_t ndim = input->shape()->NumAxes(); CHECK_GT_OR_RETURN(ndim, 0) << Error::RuntimeError() << "narrow() cannot be applied to a 0-dim tensor."; - CHECK_OR_RETURN((-ndim <= dim) && (dim <= ndim - 1)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dim << ")"; - if (narrow_dim < 0) { narrow_dim += ndim; } - const int64_t dim_length = input->shape()->At(narrow_dim); + narrow_dim = JUST(maybe_wrap_dim(narrow_dim, ndim)); + int64_t dim_length = input->shape()->At(narrow_dim); CHECK_OR_RETURN((-dim_length <= start) && (start <= dim_length)) << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim << ", " << ndim << "], but got " << start << ")"; @@ -1940,16 +1916,10 @@ class DiagonalFunctor { Maybe operator()(const std::shared_ptr& x, const int32_t& offset, const int32_t& dim1, const int32_t& dim2) const { int64_t ndims = x->shape()->NumAxes(); - - CHECK_OR_RETURN(dim1 >= -ndims && dim1 < ndims) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndims - << ", " << ndims - 1 << "], but got " << dim1 << ")"; - CHECK_OR_RETURN(dim2 >= -ndims && dim2 < ndims) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndims - << ", " << ndims - 1 << "], but got " << dim2 << ")"; - - const int32_t p_dim1 = dim1 >= 0 ? dim1 : dim1 + ndims; - const int32_t p_dim2 = dim2 >= 0 ? dim2 : dim2 + ndims; + int32_t p_dim1 = dim1; + int32_t p_dim2 = dim2; + p_dim1 = JUST(maybe_wrap_dim(p_dim1, ndims)); + p_dim2 = JUST(maybe_wrap_dim(p_dim2, ndims)); CHECK_NE_OR_RETURN(p_dim1, p_dim2) << Error::RuntimeError() << "diagonal dimensions cannot be identical " << dim1 << ", " << dim2; @@ -2362,10 +2332,7 @@ class SplitFunctor { Maybe operator()(const std::shared_ptr& x, const int64_t& split_size_or_sections, const int64_t& dim) const { int64_t axis = dim; - if (axis < 0) { axis += x->ndim(); } - CHECK_OR_RETURN(axis >= 0 && axis < x->ndim()) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -x->ndim() << ", " << x->ndim() - 1 << "], but got " << axis << ")"; + axis = JUST(maybe_wrap_dim(axis, x->ndim())); CHECK_GE_OR_RETURN(split_size_or_sections, 0) << Error::RuntimeError() << "split expects split_size be non-negative, but got split_size=" << split_size_or_sections; @@ -2389,10 +2356,7 @@ class UnbindFunctor { Maybe operator()(const std::shared_ptr& x, const int64_t& dim) const { int32_t axis = dim; const int32_t ndim = x->ndim(); - if (axis < 0) { axis += ndim; } - CHECK_OR_RETURN((dim >= -ndim) && (dim < ndim)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << "," << ndim - 1 << "], but got " << dim << ")"; + axis = JUST(maybe_wrap_dim(axis, ndim)); int32_t dim_size = x->shape()->At(axis); std::shared_ptr chunk_res = JUST(functional::Chunk(x, dim_size, axis)); TensorTuple unbinds(dim_size); @@ -2415,10 +2379,7 @@ class ChunkFunctor { << "chunk expects at least a 1-dimensional tensor."; CHECK_OR_RETURN(chunks > 0) << Error::RuntimeError() << "chunk expects `chunks` to be greater than 0, got: " << chunks; - CHECK_OR_RETURN(-ndim <= dim && dim <= (ndim - 1)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dim << ")"; - if (dim < 0) { infferd_dim += ndim; } + infferd_dim = JUST(maybe_wrap_dim(infferd_dim, ndim)); const auto dim_size = x->shape()->At(infferd_dim); int64_t split_size = (dim_size + chunks - 1) / chunks; @@ -2470,10 +2431,7 @@ class SplitWithSizeFunctor { const std::vector& split_size_or_sections, const int64_t& dim) const { int64_t axis = dim; - if (axis < 0) { axis += x->ndim(); } - CHECK_OR_RETURN(axis >= 0 && axis < x->ndim()) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -x->ndim() << ", " << x->ndim() - 1 << "], but got " << axis << ")"; + axis = JUST(maybe_wrap_dim(axis, x->ndim())); int64_t dim_size = x->shape()->At(axis); int64_t num_splits = split_size_or_sections.size(); TensorTuple splits(num_splits); @@ -2657,11 +2615,7 @@ class IndexSelectFunctor { CHECK_EQ_OR_RETURN(index_dtype_flag, true) << Error::RuntimeError() << "index_select(): Expected dtype int32 or int64 for index"; int64_t new_dim = dim; - if (dim < 0) { new_dim += input_num_axes; } - CHECK_LE_OR_RETURN(new_dim, input_num_axes) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -input_num_axes << ", " << input_num_axes - 1 << "], but got " << new_dim << ")"; - + new_dim = JUST(maybe_wrap_dim(new_dim, input_num_axes)); return JUST(functional::Gather(input, index, new_dim)); } }; @@ -2982,10 +2936,7 @@ class RepeatInterLeaveIntFunctor { int32_t dim_ = JUST(dim); const auto& input_shape = input->shape(); const int64_t& num_axes = input_shape->NumAxes(); - if (dim_ < 0) { dim_ += num_axes; } - CHECK_OR_RETURN(dim_ >= -num_axes && dim_ < num_axes) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -num_axes << ", " << num_axes - 1 << "], but got " << dim_ << ")"; + dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); std::shared_ptr repeats_expand = JUST( Expand(JUST(Constant(Shape{1}, Scalar(repeats), DType::Int32(), JUST(input->device()))), Shape{input->shape()->At(dim_)})); @@ -3030,10 +2981,7 @@ class RepeatInterLeaveTensorFunctor { int32_t dim_ = dim; const auto& input_shape = input->shape(); const int64_t& num_axes = input_shape->NumAxes(); - if (dim_ < 0) { dim_ += num_axes; } - CHECK_OR_RETURN(dim_ >= -num_axes && dim_ < num_axes) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -num_axes << ", " << num_axes - 1 << "], but got " << dim_ << ")"; + dim_ = JUST(maybe_wrap_dim(dim_, num_axes)); CHECK_OR_RETURN(repeats_shape->At(0) == input->shape()->At(dim_)) << Error::RuntimeError() << "repeats must have the same size as input along dim"; std::shared_ptr cumsum = JUST(Cumsum(repeats, 0, DType::Int32())); diff --git a/oneflow/core/functional/impl/common.cpp b/oneflow/core/functional/impl/common.cpp index b24e18cd121..11cf67a2ab9 100644 --- a/oneflow/core/functional/impl/common.cpp +++ b/oneflow/core/functional/impl/common.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/autograd/autograd_mode.h" +#include "oneflow/core/common/wrap_dim_utils.h" namespace oneflow { namespace one { @@ -39,14 +40,7 @@ Maybe> CheckAxis(const std::vector& axis, const in std::vector reduce_axis(naxis); std::vector axis_num(ndim); for (int32_t i = 0; i < naxis; i++) { - CHECK_OR_RETURN(axis[i] >= -ndim && axis[i] < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << axis[i] << ")"; - if (axis[i] < 0) { - reduce_axis[i] = axis[i] + ndim; - } else { - reduce_axis[i] = axis[i]; - } + reduce_axis[i] = JUST(maybe_wrap_dim(axis[i], ndim)); axis_num[reduce_axis[i]]++; CHECK_OR_RETURN(axis_num[reduce_axis[i]] < 2) << Error::RuntimeError() << "dim " << reduce_axis[i] diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index c200acb3c38..ab22624934a 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -372,12 +372,7 @@ class Max2Functor { const bool& keepdims) const { auto outputs = std::make_shared(2); int32_t axis = dim; - if (axis < -x->ndim() || axis >= x->ndim()) { - return Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -x->ndim() << ", " << x->ndim() - 1 << "], but got " << axis - << ")"; - } - if (axis < 0) { axis += x->ndim(); } + axis = JUST(maybe_wrap_dim(axis, x->ndim())); (*outputs)[0] = JUST(ReduceMax(x, {axis}, keepdims)); (*outputs)[1] = JUST(ArgMax(x, dim, keepdims, NullOpt)); return outputs; @@ -399,12 +394,7 @@ class Min2Functor { const bool& keepdims) const { auto outputs = std::make_shared(2); int32_t axis = dim; - if (axis < -x->ndim() || axis >= x->ndim()) { - return Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -x->ndim() << ", " << x->ndim() - 1 << "], but got " << axis - << ")"; - } - if (axis < 0) { axis += x->ndim(); } + axis = JUST(maybe_wrap_dim(axis, x->ndim())); (*outputs)[0] = JUST(ReduceMin(x, {axis}, keepdims)); (*outputs)[1] = JUST(ArgMin(x, dim, keepdims, NullOpt)); return outputs; @@ -419,13 +409,7 @@ class AminFunctor { const int32_t ndim = x->ndim(); std::vector& dims = *JUST(dim); - for (int i = 0; i < dims.size(); i++) { - if (dims[i] < -ndim || dims[i] >= ndim) { - return Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -ndim << ", " << ndim - 1 << "], but got " << dims[i] << ")"; - } - if (dims[i] < 0) { dims[i] += ndim; } - } + for (int i = 0; i < dims.size(); i++) { dims[i] = JUST(maybe_wrap_dim(dims[i], ndim)); } return ReduceMin(x, dims, keepdim); } }; @@ -438,13 +422,7 @@ class AmaxFunctor { const int32_t ndim = x->ndim(); std::vector& dims = *JUST(dim); - for (int i = 0; i < dims.size(); i++) { - if (dims[i] < -ndim || dims[i] >= ndim) { - return Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -ndim << ", " << ndim - 1 << "], but got " << dims[i] << ")"; - } - if (dims[i] < 0) { dims[i] += ndim; } - } + for (int i = 0; i < dims.size(); i++) { dims[i] = JUST(maybe_wrap_dim(dims[i], ndim)); } return ReduceMax(x, dims, keepdim); } }; @@ -820,18 +798,14 @@ class MedianWithIndicesFunctor { const bool& keepdim) const { MutableAttrMap attrs; int32_t axis = dim; - if (axis < -x->ndim() || axis >= x->ndim()) { - return Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -x->ndim() << ", " << x->ndim() - 1 << "], but got " << axis - << ")"; - } - if (axis < 0) { axis += x->ndim(); } + const int64_t ndim = x->ndim(); + axis = JUST(maybe_wrap_dim(axis, ndim)); std::shared_ptr tensor = x; if (x->dim(axis) == 0) { return Error::IndexError() << "IndexError: Expected reduction dim " << axis << " to have non-zero size."; } - if (axis != x->ndim() - 1) { + if (axis != ndim - 1) { tensor = JUST(functional::Squeeze( JUST(functional::Transpose2dim(JUST(functional::Unsqueeze(x, -1)), axis, -1)), std::vector({axis}))); @@ -898,10 +872,7 @@ class TransposeFunctor { // so copy it to local var and do modification. auto positive_perm = permute; for (auto i = 0; i < positive_perm.size(); i++) { - if (positive_perm[i] < 0) { positive_perm[i] += ndim; } - CHECK_OR_RETURN(positive_perm[i] >= 0 && positive_perm[i] < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << "," << ndim << " ) but got " << positive_perm[i] << ")"; + positive_perm[i] = JUST(maybe_wrap_dim(positive_perm[i], ndim)); } // currently, view only support eager and local mode if (view::IsViewApplicable(input)) { return JUST(view::Transpose(input, positive_perm)); } @@ -927,15 +898,8 @@ class Transpose2dimFunctor { int32_t dim_0 = dim0; int32_t dim_1 = dim1; - if (dim0 < 0) { dim_0 += ndim; } - if (dim1 < 0) { dim_1 += ndim; } - - CHECK_OR_RETURN(dim_0 >= 0 && dim0 < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dim_0 << ")"; - CHECK_OR_RETURN(dim_1 >= 0 && dim1 < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dim_1 << ")"; + dim_0 = JUST(maybe_wrap_dim(dim_0, ndim)); + dim_1 = JUST(maybe_wrap_dim(dim_1, ndim)); for (int32_t i = 0; i < ndim; ++i) { permute.emplace_back(i); } std::swap(permute[dim_0], permute[dim_1]); Shape shape(DimVector(permute.begin(), permute.end())); @@ -1682,10 +1646,7 @@ class SelectFunctor { const int32_t& index) const { int32_t ndim = input->ndim(); CHECK_OR_RETURN(ndim > 0) << "select() cannot be applied to a 0-dim tensor."; - CHECK_OR_RETURN((dim >= -ndim) && (dim < ndim)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << "," << ndim - 1 << "], but got " << dim << ")"; - int32_t pos_dim = dim >= 0 ? dim : dim + ndim; + int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim)); auto size = input->dim(pos_dim); CHECK_OR_RETURN((index >= -size) && (index < size)) << "Index out of range (expected to be in range of [" << -size << "," << size - 1 @@ -2049,9 +2010,7 @@ class VarianceFunctor { for (int i = 0; i < ndim; i++) { axis.emplace_back(i); } } else { std::vector& dims = *JUST(dim); - CHECK_GE_OR_RETURN(ndim, dims.size()) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << dims.size() << ")"; + JUST(maybe_wrap_dim(dims.size(), ndim)); std::sort(dims.begin(), dims.end()); axis.assign(dims.begin(), dims.end()); } @@ -2089,10 +2048,7 @@ class MovedimVecFunctor { std::vector is_used(ndim, false); FOR_RANGE(size_t, i, 0, perm.size()) { int32_t item = perm[i]; - if (item < 0) { item += ndim; } - CHECK_OR_RETURN(item >= -ndim && item < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << ", " << ndim - 1 << "], but got " << item << ")"; + item = JUST(maybe_wrap_dim(item, ndim)); CHECK_EQ_OR_RETURN(is_used[item], false) << "repeated dim in " << desc; is_used[item] = true; @@ -2159,10 +2115,7 @@ class TensorSplitVecFunctor { const std::vector& indices_or_sections, const int32_t& dim) const { int32_t ndim = input->ndim(); - CHECK_OR_RETURN((dim >= -ndim) && (dim < ndim)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << "," << ndim - 1 << "], but got " << dim << ")"; - int32_t pos_dim = dim >= 0 ? dim : dim + ndim; + int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim)); std::vector start(ndim, 0); std::vector stop(ndim); @@ -2190,12 +2143,9 @@ class TensorSplitIntFunctor { Maybe operator()(const std::shared_ptr& input, const int32_t& indices_or_sections, const int32_t& dim) const { int32_t ndim = input->ndim(); - CHECK_OR_RETURN((dim >= -ndim) && (dim < ndim)) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << "," << ndim - 1 << "], but got " << dim << ")"; + int32_t pos_dim = JUST(maybe_wrap_dim(dim, ndim)); CHECK_OR_RETURN(indices_or_sections > 0) << "number of sections must be larger than 0, got ," << indices_or_sections << ");"; - int32_t pos_dim = dim >= 0 ? dim : dim + ndim; const auto dim_size = input->dim(pos_dim); int64_t min_split_size = dim_size / indices_or_sections; @@ -2319,10 +2269,7 @@ class CumBaseFunctor { Maybe operator()(const std::shared_ptr& input, int64_t dim, const Optional>& dtype) const { auto ndim = input->ndim(); - if (dim < 0) { dim += ndim; } - CHECK_OR_RETURN(dim >= 0 && dim < ndim) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" << -ndim - << "," << ndim << " ) but got " << dim << ")"; + dim = JUST(maybe_wrap_dim(dim, ndim)); MutableAttrMap attrs; JUST(attrs.SetAttr("dim", dim)); diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index af0ff519dd4..532ef47be49 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -387,14 +387,8 @@ class TensorDotFunctor { std::vector dot_dims_a(dims_a.begin(), dims_a.end()); std::vector dot_dims_b(dims_b.begin(), dims_b.end()); for (int64_t i = 0; i < dot_dims_a.size(); i++) { - CHECK_OR_RETURN(dot_dims_a[i] >= -a->ndim() && dot_dims_a[i] < a->ndim()) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -a->ndim() << ", " << a->ndim() - 1 << "], but got " << dot_dims_a[i] << ")"; - CHECK_OR_RETURN(dot_dims_b[i] >= -b->ndim() && dot_dims_b[i] < b->ndim()) - << Error::IndexError() << "Dimension out of range (expected to be in range of [" - << -b->ndim() << ", " << b->ndim() - 1 << "], but got " << dot_dims_b[i] << ")"; - dot_dims_a[i] = dot_dims_a[i] < 0 ? dot_dims_a[i] + a->ndim() : dot_dims_a[i]; - dot_dims_b[i] = dot_dims_b[i] < 0 ? dot_dims_b[i] + b->ndim() : dot_dims_b[i]; + dot_dims_a[i] = JUST(maybe_wrap_dim(dot_dims_a[i], a->ndim())); + dot_dims_b[i] = JUST(maybe_wrap_dim(dot_dims_b[i], b->ndim())); } std::vector if_dot_dims_a(a->ndim(), false); std::vector if_dot_dims_b(b->ndim(), false);