-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Topi]Add Sort Op to Relay #6978
Conversation
cc @kevinthesun, PTAL when you have time |
For dynamic topk, can we use Thrust? IMO after sorting with Thrust the only work left is dynamic strided_slice in topi. Not quite sure whether we need to do this in relay level. |
I tried this PR, I see the following error when used with dynamic input shape
|
python/tvm/topi/cuda/sort.py
Outdated
@@ -593,3 +694,82 @@ def schedule_topk(outs): | |||
The computation schedule for the op. | |||
""" | |||
return _schedule_sort(outs) | |||
|
|||
|
|||
def _dyn_topk_legalize(attrs, inputs, arg_types): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest we improve gpu dynamic strided_slice and it helps both dynamic topk and argwhere. For sorting, we can still rely on existing thrust routine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any tips for debugging generated code? I have a branch where I've done this a a topi composition, it passes if I compile with AddressSanitizer, but segfaults with a normal build in the generated code. @zhiics attempted to do argwhere, but gets compilation errors with topi dropping variables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also get segfault using current topi dynamic strided slice to cut topk result. I'll spend some time on it.
@anijain2305 I'll add a shape func for sort and a regression test @kevinthesun I added a thrust implementation of sort, and tested dynamic topk with thrust enabled, it worked just fine. |
#7018 |
closing in favor of #7018 |
@kevinthesun found a fix for dynamic topk, so I don't think I need this legalization anymore, but do you think the sort op at the relay level is still useful? |
yes hummingbird people have asked for a converter for pytorch sort op, that would be exactly what I need. Thanks. cc @interesaaat |
Okay, I'll reopen and rip out the legalization. |
@kevinthesun Can you re-review without the topk legalization? |
lambda ins, outs: tvm.tir.call_packed( | ||
"tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend | ||
), | ||
out_buffers=[value_buf, indices_buf], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no need to pass indices_buf
for normal sort, but tvm.contrib.thrust.sort
always expects indices_buf
to be passed in... What we call tvm.contrib.thrust.sort
is really argsort.
For optimal performance we should have the true tvm.contrib.thrust.sort
someday.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, you're right, we don't strictly need that. Want me to add a todo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that would be nice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix remaining doc issues, otherwise LGTM.
Thanks @mbrookhart @zhiics @kevinthesun @jwfromm |
* Add sort op to relay * fix lint * fix sort docstring * fix docs * add TODO, shape_func, cleanup * add dynamic tests for sort and argsort
* Add sort op to relay * fix lint * fix sort docstring * fix docs * add TODO, shape_func, cleanup * add dynamic tests for sort and argsort
* Add sort op to relay * fix lint * fix sort docstring * fix docs * add TODO, shape_func, cleanup * add dynamic tests for sort and argsort
* Add sort op to relay * fix lint * fix sort docstring * fix docs * add TODO, shape_func, cleanup * add dynamic tests for sort and argsort
Parameters | ||
---------- | ||
outs: Array of Tensor | ||
The computation graph description of argsort |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: argsort -> sort
Re purposing PR for Relay Sort