Fix gather kernel check #7979
XPU_1D_KERNEL_LOOP(index_offset, elem_cnt) {
  IDX_T coordinate[kDimGatherMaxDimCount] = {0};
  const IDX_T x = index[index_offset];
#ifdef WITH_CUDA
I asked juncheng; this should be __CUDA_ARCH__.
ok
What is the difference between the two here?
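On the difference asked about here: __CUDA_ARCH__ is predefined by nvcc only while it compiles device code, whereas WITH_CUDA is (as far as I understand, this is an assumption about the project's build setup) a build-level macro that is set whenever CUDA support is enabled, so it would also be visible in host code. A minimal sketch of the distinction, where CompiledForDevice is a hypothetical helper and not part of this PR:

```cpp
// Hypothetical helper, not from the PR: reports which compilation pass
// produced this function body.
// __CUDA_ARCH__ is predefined by nvcc only during the device pass, so a
// plain host compiler (or nvcc's host pass) takes the #else branch.
inline bool CompiledForDevice() {
#ifdef __CUDA_ARCH__
  return true;   // device pass: device-only code can be placed here
#else
  return false;  // host pass: __CUDA_ARCH__ is undefined
#endif
}
```

Compiled as ordinary C++, CompiledForDevice() returns false, which is why a guard on __CUDA_ARCH__ keeps the device-side check out of any host build of the same loop body.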
XPU_1D_KERNEL_LOOP(index_offset, elem_cnt) {
  IDX_T coordinate[kDimGatherMaxDimCount] = {0};
  const IDX_T x = index[index_offset];
#ifdef __CUDA_ARCH__
  assert(x < dim_length && "gather index is out of bounds");
- assert(x < dim_length && "gather index is out of bounds");
+ assert(x < dim_length); // gather index is out of bounds
Wouldn't a comment be enough? The string only serves as a comment here, doesn't it?
It doesn't seem to be quite the same: if this check is triggered, the error is reported directly on the command line.
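I don't think so; assert shouldn't have any ability to print.
The string here is just an expression whose value is a valid pointer address (never 0). So the value that assert tests depends entirely on the comparison with dim_length, which is why I think the string only acts as a comment.
Also, I searched the codebase: none of the asserts in .cu files add a string like this.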
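Both points in this exchange can be checked on the host with a small sketch. The string literal never changes the truth value of the condition, but the standard assert macro reports the text of its argument when it fails, and that text includes the literal. EXPR_TEXT, IndexCheck and IndexCheckText are hypothetical names used only for illustration; what the device-side assert prints for this kernel has not been verified here.

```cpp
#include <string>

// Hypothetical macro for illustration: stringifies its argument, which is how
// an assert macro obtains the expression text for its diagnostic.
#define EXPR_TEXT(expr) #expr

// The string literal decays to a non-null pointer, so it is always truthy and
// the result depends only on the comparison.
inline bool IndexCheck(long x, long dim_length) {
  return x < dim_length && "gather index is out of bounds";
}

// The text that an assert on the same expression has available to report.
inline std::string IndexCheckText() {
  return EXPR_TEXT(x < dim_length && "gather index is out of bounds");
}
```

IndexCheck(3, 5) is true and IndexCheck(7, 5) is false, while IndexCheckText() contains the message, so the `&& "..."` form reads as a comment to the compiler's logic but can still surface in the failure output.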
@@ -24,30 +24,30 @@ namespace user_op {
template<typename IN_T, typename IDX_T>
__global__ void DoCUDADimGather(const DimOpIndexNdHelper<IDX_T> input_nd_helper,
const DimOpIndexNdHelper<IDX_T> index_nd_helper, int ndim,
int64_t elem_cnt, int32_t dim, const IDX_T* index,
int64_t elem_cnt, int64_t dim_length, int32_t dim, const IDX_T* index,
int64_t elem_cnt, int64_t dim_length, int32_t dim, const IDX_T* index,
For int64_t dim_length: if int32_t is enough, it would be better to use int32_t, because slimming down the arguments passed to a global kernel matters. I'm not sure whether int32_t is sufficient, though; you are in a better position to judge.
OK, I'll change it back to int32_t.
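If the length is still held as int64_t on the host, one way to make the narrowing safe is to check the range before the kernel launch. This is only a sketch under that assumption; NarrowDimLength is a hypothetical helper and not something this PR adds.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical helper, not from the PR: narrows a dimension length to int32_t
// on the host, refusing values that do not fit.
inline int32_t NarrowDimLength(int64_t dim_length) {
  if (dim_length < 0 || dim_length > std::numeric_limits<int32_t>::max()) {
    throw std::out_of_range("dim_length does not fit in int32_t");
  }
  return static_cast<int32_t>(dim_length);
}
```

NarrowDimLength(1024) returns 1024, while a length of 2^31 throws instead of silently wrapping.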
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7979/
CI failed when running job: cuda-module. PR label automerge has been removed
CI failed when running job: cuda-benchmark. PR label automerge has been removed
Adds an index check to the gather op so that index values are restricted to the length of the specified dimension, and aligns the error message.
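The intent of the check can be illustrated with a 1-D CPU sketch. Gather1D is a hypothetical function, not the PR's kernel, and it throws where the kernel asserts. The kernel shown in this thread only asserts x < dim_length; the lower-bound test below is an extra assumption of the sketch.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical 1-D CPU sketch, not the PR's kernel: output[i] = input[index[i]].
inline std::vector<float> Gather1D(const std::vector<float>& input,
                                   const std::vector<int64_t>& index) {
  const int64_t dim_length = static_cast<int64_t>(input.size());
  std::vector<float> output;
  output.reserve(index.size());
  for (const int64_t x : index) {
    // The PR's kernel asserts x < dim_length; the x >= 0 half is an
    // assumption added for this sketch.
    if (x < 0 || x >= dim_length) {
      throw std::out_of_range("gather index is out of bounds");
    }
    output.push_back(input[static_cast<std::size_t>(x)]);
  }
  return output;
}
```

With input {10, 20, 30}, the index list {2, 0} gathers {30, 10}, and an index of 3 is rejected because it equals the dimension length.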