[Topi] Fix GPU Dynamic Topk by Improving Dynamic Strided Slice in Topi #7018
:) This is awesome, I didn't think to do indexdiv(end[i] - begin[i], strides[i]), how did you find the issue?
A few nitpicks below for code readability.
Can you also enable the test here?
tvm/tests/python/relay/dyn/test_dynamic_op_level6.py
Lines 25 to 27 in e212f96
# TODO(mbrookhart): Enable when we can get it working
# @tvm.testing.uses_gpu
def test_dynamic_topk():
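For readers following the indexdiv(end[i] - begin[i], strides[i]) remark, here is a minimal sketch of that per-axis extent computation in plain Python. The helper name `strided_slice_extent` is hypothetical, and `//` stands in for indexdiv under the assumption that all operands are non-negative; it is not taken from the PR's code.

```python
def strided_slice_extent(begin, end, strides):
    """Per-axis output extent, computed as (end - begin) // stride.

    Hypothetical helper: `//` plays the role of indexdiv and is only
    meaningful here for non-negative begin/end and positive strides.
    """
    return [(e - b) // s for b, e, s in zip(begin, end, strides)]


# With unit strides (the case topk uses) the extent is simply end - begin.
shape = strided_slice_extent([0, 0], [4, 3], [1, 1])
print(shape)
```

When the stride does not divide end - begin evenly, floor division gives one element fewer than Python's own `range(begin, end, stride)`, so this sketch should only be read as an illustration of the expression quoted above.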
Out of curiosity, why no nvptx?
LGTM.
Need to fix the CI
@mbrookhart Generally we need Thrust for these dynamic sorting ops. nvptx will have issues compiling them.
I don't love making Thrust a necessary component unless we automatically enable it when we turn on cuda? If we don't support the tir-based sort, should we remove it from the codebase?
I think we can raise an exception when compiling dynamic topk if Thrust is not enabled. Building with Thrust usually takes extra effort since it requires cmake >= 3.13. Users can enable it when necessary. For the TVM cuda sort, I'm not sure whether it covers some cases that Thrust doesn't. Maybe we can keep it for a while.
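A rough sketch of the "raise an exception if Thrust is not enabled" idea, written against an injected lookup function so it runs without TVM. The helper `require_thrust` is hypothetical, and the registry name "tvm.contrib.thrust.sort" and the USE_THRUST build flag are assumptions about how a Thrust-enabled build exposes its sort; in TVM itself a similar lookup could presumably go through `tvm.get_global_func(name, allow_missing=True)`.

```python
def require_thrust(lookup, op_name="dynamic topk"):
    """Return the Thrust sort function or fail with a clear message.

    `lookup(name)` is any callable returning a function or None.
    The registry name below is an assumption, not confirmed by this thread.
    """
    sort_fn = lookup("tvm.contrib.thrust.sort")
    if sort_fn is None:
        raise RuntimeError(
            f"{op_name} on CUDA needs a build with Thrust enabled "
            "(assumed flag: USE_THRUST=ON)"
        )
    return sort_fn


# Usage with a plain dict standing in for the global function registry.
registry = {"tvm.contrib.thrust.sort": sorted}
sort_fn = require_thrust(registry.get)
```

Failing early at compile time this way would keep Thrust optional for users who never compile dynamic sorting ops.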
I think we can get around the unit test error without forcing users to enable thrust? Just requesting changes while we chat about it, will reapprove once we decide.
-    if k > 0:
+    if not isinstance(k, int) or k > 0:
         beg = [0] * ndim
-        end = data.shape[:-1] + [k]
-        out = [strided_slice(o, beg, end) for o in out]
+        end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
+        strides = [1] * ndim
+        out = [strided_slice(o, beg, end, strides) for o in out]
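To make the static-versus-dynamic branch in this hunk easier to follow, here is a self-contained sketch of how the slice arguments are chosen. `SymbolicDim` is a hypothetical stand-in for `tvm.te.size_var`, and `topk_slice_args` is an illustrative helper, not a function from the PR.

```python
class SymbolicDim:
    """Hypothetical placeholder for a symbolic size such as tvm.te.size_var."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"SymbolicDim({self.name!r})"


def topk_slice_args(shape, k):
    """Return (beg, end, strides) for trimming topk output, or None.

    A static k <= 0 means no slicing is applied; a non-int k is treated
    as dynamic and gets a symbolic extent on the last axis.
    """
    if isinstance(k, int) and k <= 0:
        return None
    ndim = len(shape)
    beg = [0] * ndim
    last = k if isinstance(k, int) else SymbolicDim("dim")
    end = list(shape[:-1]) + [last]
    strides = [1] * ndim
    return beg, end, strides


print(topk_slice_args([4, 10], 3))
```

The guard here is the logical equivalent of `not isinstance(k, int) or k > 0` in the diff: slicing happens for every dynamic k and for every positive static k.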
@kevinthesun, why don't we just repeat this change in the tir topk above? that would fix the unit test, I think.
I modified cuda topk so that the topk test in dyn passes. However, the topk case in test any, where data has a dynamic shape, can't pass without Thrust. I've disabled that test for now.
I think without Thrust, we then have to fix sort. We can probably disable the test for now, come back to work on sorting, and then enable the test. This would at least unblock downstream users so they can run models through Thrust. @mbrookhart @icemelon9 @kevinthesun what do you think?
I'm not really sure what's wrong with the tir sort, do we have a regression test/issue we could track?
AFAIK cuda sort has several issues:
There is no clear path to a solution to these problems. For now the best way is to let users turn on Thrust when they want to compile sort-related ops on NVIDIA GPUs.
Yeah, the perf of the kernel isn't great, and I see some thread definition issues that will cause problems with dynamic shapes. Do we have a flaky test we can include? I don't think it's important for this PR, but it might be interesting to tackle later.
@mbrookhart yeah, argwhere is flaky on large inputs if sort is used
:/ OddEvenTransportSort should be stable, but something looks very wrong about the threading in this kernel. I'll see if I can edit it to solve these problems at some point in the near-ish future. If somehow this sort isn't stable, that would easily explain flakiness in argwhere/argsort.
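For context on the stability claim, a sequential sketch of odd-even transposition sort in plain Python. This is an illustration of the general algorithm, not the TVM kernel; the function name and the `key` parameter are chosen for the example. Because it only swaps adjacent elements, and only when the left key is strictly greater, equal keys never cross each other.

```python
def odd_even_transposition_sort(items, key=lambda x: x):
    """Sequential simulation of odd-even transposition sort.

    Alternates between comparing (0,1),(2,3),... and (1,2),(3,4),...
    for n phases. Swapping only on strict '>' keeps equal keys in order.
    """
    a = list(items)
    n = len(a)
    for phase in range(n):
        start = phase % 2  # even phases start at 0, odd phases at 1
        for i in range(start, n - 1, 2):
            if key(a[i]) > key(a[i + 1]):
                a[i], a[i + 1] = a[i + 1], a[i]
    return a


pairs = [(2, "a"), (1, "b"), (2, "c"), (1, "d")]
result = odd_even_transposition_sort(pairs, key=lambda p: p[0])
print(result)
```

In a GPU version each phase's comparisons run in parallel across threads, so a threading bug (for example, too few threads for the input size) could break both correctness and stability even though the algorithm itself is stable.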
Thanks @kevinthesun @mbrookhart @icemelon9
apache#7018) * Fix GPU dynamic Topk * Fix style * Minor fix * Simplfy dynamic checking * Fix lint * More improvements * Disable test any topk
This fix also works for gpu argwhere.
@zhiics @anijain2305 @mbrookhart @Laurawly