Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gather kernel check #7979

Merged
merged 37 commits into from
Apr 17, 2022
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
e6e6973
fix reduce_sum scalar check bug
BBuf Mar 22, 2022
a0abdd5
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 23, 2022
00522df
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 23, 2022
68e0e08
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 24, 2022
0b90f9b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 28, 2022
d81aa80
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 29, 2022
81954d7
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 29, 2022
c910fbe
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 29, 2022
32702cf
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 29, 2022
9fa0513
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 30, 2022
8402ce3
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Mar 31, 2022
e3bf835
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 1, 2022
7a46669
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 2, 2022
5786ed5
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 2, 2022
1f66c27
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 3, 2022
265c15a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 5, 2022
68f1baf
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 6, 2022
8f4ce10
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 7, 2022
df08558
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
BBuf Apr 7, 2022
b7f4e5a
fix gather kernel check bug
BBuf Apr 7, 2022
5efac3c
Update dim_gather_kernel_util.h
BBuf Apr 8, 2022
0680e85
fix comment
BBuf Apr 8, 2022
45395ef
fix comment
BBuf Apr 8, 2022
11ef1ab
Merge branch 'master' into fix_gather_kernel_check
BBuf Apr 9, 2022
0a2413a
auto format by CI
oneflow-ci-bot Apr 9, 2022
1aa679d
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 9, 2022
3c3c29b
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 9, 2022
e20a199
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 9, 2022
48a41dc
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 10, 2022
99697af
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 10, 2022
d534905
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 10, 2022
d64eb4c
fix bug
BBuf Apr 11, 2022
cda8aa4
revert
BBuf Apr 11, 2022
84db434
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 13, 2022
6d4447d
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 16, 2022
e4f3596
Merge branch 'master' into fix_gather_kernel_check
BBuf Apr 17, 2022
32ad956
Merge branch 'master' into fix_gather_kernel_check
mergify[bot] Apr 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions oneflow/user/kernels/dim_gather_kernel_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ template<typename IN_T, typename IDX_T>
struct DimGatherFunctor<DeviceType::kCPU, IN_T, IDX_T> final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index, input,
output);
int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,
IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim,
index, input, output);
}
};

Expand Down
12 changes: 6 additions & 6 deletions oneflow/user/kernels/dim_gather_kernel_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,30 @@ namespace user_op {
template<typename IN_T, typename IDX_T>
__global__ void DoCUDADimGather(const DimOpIndexNdHelper<IDX_T> input_nd_helper,
const DimOpIndexNdHelper<IDX_T> index_nd_helper, int ndim,
int64_t elem_cnt, int32_t dim, const IDX_T* index,
int64_t elem_cnt, int32_t dim_length, int32_t dim, const IDX_T* index,
const IN_T* input, IN_T* output) {
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index, input,
DoDimGather<IN_T, IDX_T>(input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input,
output);
}

template<typename IDX_T, typename IN_T>
struct DimGatherFunctor<DeviceType::kCUDA, IN_T, IDX_T> final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
RUN_CUDA_KERNEL((DoCUDADimGather<IN_T, IDX_T>), stream, BlocksNum4ThreadsNum(elem_cnt),
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index, input, output);
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index, input, output);
}
};

// float16 special case of DimGatherFunctor template
template<typename IDX_T>
struct DimGatherFunctor<DeviceType::kCUDA, float16, IDX_T> final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt, int32_t dim_length,
int32_t dim, const IDX_T* index, const float16* input, float16* output) {
RUN_CUDA_KERNEL((DoCUDADimGather<half, IDX_T>), stream, BlocksNum4ThreadsNum(elem_cnt),
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim, index,
input_nd_helper, index_nd_helper, ndim, elem_cnt, dim_length, dim, index,
reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output));
}
};
Expand Down
13 changes: 10 additions & 3 deletions oneflow/user/kernels/dim_gather_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,24 @@ template<DeviceType device_type, typename IN_T, typename IDX_T>
struct DimGatherFunctor final {
void operator()(ep::Stream* stream, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output);
int32_t dim_length, int32_t dim, const IDX_T* index, const IN_T* input,
IN_T* output);
};

template<typename IN_T, typename IDX_T>
OF_DEVICE_FUNC void DoDimGather(const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& index_nd_helper, int ndim,
int64_t elem_cnt, int32_t dim, const IDX_T* index,
const IN_T* input, IN_T* output) {
int64_t elem_cnt, int32_t dim_length, int32_t dim,
const IDX_T* index, const IN_T* input, IN_T* output) {
XPU_1D_KERNEL_LOOP(index_offset, elem_cnt) {
IDX_T coordinate[kDimGatherMaxDimCount] = {0};
const IDX_T x = index[index_offset];
#ifdef __CUDA_ARCH__
BBuf marked this conversation as resolved.
Show resolved Hide resolved
assert(x < dim_length && "gather index is out of bounds");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(x < dim_length && "gather index is out of bounds");
assert(x < dim_length); // gather index is out of bounds

是不是用注释就可以了,这个字符串的作业也只是注释作用?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好像不太一样,如果触发了这个检查会直接在命令行上抛出这个错误。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不会吧,assert应该是没有打印的能力的。

这里的字符串只是一个表达式,表达式的值是一个有效的指针地址(总不为0)。所以assert到底值是多少,完全取决于dim_length,所以这个字符串我觉得就只是注释用了。

另外我搜了下代码里,.cu 文件里的assert都没这样加字符串的

#else
CHECK_LE(x, dim_length) << "RuntimeError: index " << x << " is out of bounds for dimension "
<< dim << " with size " << dim_length;
#endif
index_nd_helper.OffsetToNdIndex(index_offset, coordinate, ndim);
coordinate[dim] = x;

Expand Down
6 changes: 3 additions & 3 deletions oneflow/user/kernels/dim_gather_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class DimGatherKernel final : public user_op::OpKernel {
shape2dims(index_tensor->shape());
DimOpIndexNdHelper<IDX_T> index_nd_helper(shape_vec.data(), ndim);

DimGatherFunctor<device_type, IN_T, IDX_T>()(ctx->stream(), input_nd_helper, index_nd_helper,
ndim, index_tensor->shape().elem_cnt(), dim, index,
input, output);
DimGatherFunctor<device_type, IN_T, IDX_T>()(
ctx->stream(), input_nd_helper, index_nd_helper, ndim, index_tensor->shape().elem_cnt(),
index_tensor->shape().At(dim), dim, index, input, output);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
Expand Down