Skip to content

Commit

Permalink
Fix gather kernel check (#7979)
Browse files Browse the repository at this point in the history
* fix reduce_sum scalar check bug

* fix gather kernel check bug

* Update dim_gather_kernel_util.h

* fix comment

* fix comment

* auto format by CI

* fix bug

* revert

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 17, 2022
1 parent a6e3d54 commit 6e9431d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
7 changes: 4 additions & 3 deletions oneflow/user/kernels/dim_gather_kernel_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ template<typename IN_T, typename IDX_T>
struct DimGatherFunctor<DeviceType::kCPU, IN_T, IDX_T> final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index, input,
output);
int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,
IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim,
index, input, output);
}
};

Expand Down
19 changes: 11 additions & 8 deletions oneflow/user/kernels/dim_gather_kernel_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ namespace user_op {
template<typename IN_T, typename IDX_T>
__global__ void DoCUDADimGather(const DimOpIndexNdHelper<IDX_T> input_nd_helper,
const DimOpIndexNdHelper<IDX_T> index_nd_helper, int ndim,
int64_t elem_cnt, int32_t dim, const IDX_T* index,
const IN_T* input, IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index, input,
output);
int64_t elem_cnt, int32_t dim_length, int32_t dim,
const IDX_T* index, const IN_T* input, IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index,
input, output);
}

template<typename IDX_T, typename IN_T>
struct DimGatherFunctor<DeviceType::kCUDA, IN_T, IDX_T> final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,
IN_T* output) {
RUN_CUDA_KERNEL((DoCUDADimGather<IN_T, IDX_T>), stream, BlocksNum4ThreadsNum(elem_cnt),
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index, input, output);
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input,
output);
}
};

Expand All @@ -45,9 +47,10 @@ template<typename IDX_T>
struct DimGatherFunctor<DeviceType::kCUDA, float16, IDX_T> final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const float16* input, float16* output) {
int32_t dim_length, int32_t dim, const IDX_T* index, const float16* input,
float16* output) {
RUN_CUDA_KERNEL((DoCUDADimGather<half, IDX_T>), stream, BlocksNum4ThreadsNum(elem_cnt),
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index,
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index,
reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output));
}
};
Expand Down
13 changes: 10 additions & 3 deletions oneflow/user/kernels/dim_gather_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,24 @@ template<DeviceType device_type, typename IN_T, typename IDX_T>
struct DimGatherFunctor final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output);
int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,
IN_T* output);
};

template<typename IN_T, typename IDX_T>
OF_DEVICE_FUNC void DoDimGather(const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim,
int64_t elem_cnt, int32_t dim, const IDX_T* index,
const IN_T* input, IN_T* output) {
int64_t elem_cnt, int32_t dim_length, int32_t dim,
const IDX_T* index, const IN_T* input, IN_T* output) {
XPU_1D_KERNEL_LOOP(index_offset, elem_cnt) {
IDX_T coordinate[kDimGatherMaxDimCount] = {0};
const IDX_T x = index[index_offset];
#ifdef __CUDA_ARCH__
assert(x < dim_length && "gather index is out of bounds");
#else
CHECK_LE(x, dim_length) << "RuntimeError: index " << x << " is out of bounds for dimension "
<< dim << " with size " << dim_length;
#endif
index_nd_helper.OffsetToNdIndex(index_offset, coordinate, ndim);
coordinate[dim] = x;

Expand Down
7 changes: 3 additions & 4 deletions oneflow/user/kernels/dim_gather_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ class DimGatherKernel final : public user_op::OpKernel {
DimOpIndexNdHelper<IDX_T> input_nd_helper(shape_vec.data(), ndim);
shape2dims(index_tensor->shape());
DimOpIndexNdHelper<IDX_T> index_nd_helper(shape_vec.data(), ndim);

DimGatherFunctor<device_type, IN_T, IDX_T>()(ctx->stream(), input_nd_helper, index_nd_helper,
ndim, index_tensor->shape().elem_cnt(), dim, index,
input, output);
DimGatherFunctor<device_type, IN_T, IDX_T>()(
ctx->stream(), input_nd_helper, index_nd_helper, ndim, index_tensor->shape().elem_cnt(),
input_tensor->shape().At(dim), dim, index, input, output);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
Expand Down

0 comments on commit 6e9431d

Please sign in to comment.