-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay, TOPI] Add searchsorted op #9184
Merged
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
1622025
Add relay definition
masahi a43e0c8
1D cpu test working
masahi 36cb3be
multi dim working
masahi 8f1f010
gpu version working
masahi 445852c
check shape in type rel
masahi 6b1adca
support side
masahi 98b88fc
use target specfic max threads
masahi 5ca105c
add relay boilerplate
masahi a055dd3
relay test working
masahi 2584f45
cleanup topi test
masahi 686e222
fix test
masahi 1c7a0ff
add torch converter
masahi a57b081
handle other cases
masahi cda4957
more topi test
masahi 16ef469
support torch bucketize
masahi 6fe38ad
update doc
masahi ce02cef
fix tests
masahi fe01efe
fix lint
masahi 3b18a32
rebase fix
masahi bac6dc5
make the test case smaller
masahi 5eb0c15
add tests for edge cases
masahi 5fc1bbb
replace "side" attribute with boolean "right"
masahi 4775b72
add more descrition to binear_search IR gen params
masahi c3eace8
return index from binary_search rather than update inplace
masahi 169088c
remove unused argument
masahi 431db6b
format fix
masahi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,3 +59,4 @@ | |
from .sparse_reshape import * | ||
from .transform import * | ||
from .unique import * | ||
from .searchsorted import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name | ||
"""searchsorted operator for GPU""" | ||
import tvm | ||
from tvm import te | ||
from .. import utils | ||
from ..searchsorted import binary_search | ||
|
||
|
||
def searchsorted(sorted_sequence, values, right, out_dtype="int64"): | ||
"""Find indices where elements should be inserted to maintain order. | ||
If `sorted_sequence` is N-dimensional, the innermost dimension of | ||
`values` are searched in the corresponding dimension of `sorted_sequence`. | ||
|
||
Parameters | ||
---------- | ||
sorted_sequence : te.Tensor | ||
N-D or 1-D Tensor, containing monotonically increasing sequence | ||
on the innermost dimension. | ||
|
||
values : te.Tensor | ||
N-D Tensor containing the search values. When `sorted_sequence` is 1-D, | ||
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` | ||
and `values` must be the same, and outer N-1 axes must have the same size. | ||
|
||
right : bool, optional | ||
Controls which index is returned if a value lands exactly on one of sorted values. If | ||
False, the index of the first suitable location found is given. If true, return the | ||
last such index. If there is no suitable index, return either 0 or N (where N is the | ||
size of the innermost dimension). | ||
|
||
dtype : string, optional | ||
The data type of the output indices. | ||
|
||
Returns | ||
------- | ||
indices : te.Tensor | ||
Tensor with same shape as values, representing the indices of | ||
elements of `values` if they are inserted in `sorted_sequence`. | ||
""" | ||
|
||
def ir(sorted_sequence, values, indices): | ||
ib = tvm.tir.ir_builder.create() | ||
sorted_sequence_shape = sorted_sequence.shape | ||
values_shape = values.shape | ||
num_search = utils.prod(values_shape) | ||
search_range = sorted_sequence_shape[-1] | ||
|
||
sorted_sequence = ib.buffer_ptr(sorted_sequence) | ||
values = ib.buffer_ptr(values) | ||
indices = ib.buffer_ptr(indices) | ||
|
||
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) | ||
bx = te.thread_axis("blockIdx.x") | ||
tx = te.thread_axis("threadIdx.x") | ||
ib.scope_attr( | ||
bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads) | ||
) | ||
ib.scope_attr(tx, "thread_extent", max_threads) | ||
tid = bx * max_threads + tx | ||
|
||
with ib.if_scope(tid < num_search): | ||
if len(sorted_sequence_shape) == 1: | ||
sequence_offset = 0 | ||
else: | ||
sequence_id = tid // values_shape[-1] | ||
sequence_offset = sequence_id * search_range | ||
|
||
indices[tid] = binary_search( | ||
ib, | ||
sequence_offset, | ||
search_range, | ||
sorted_sequence, | ||
values[tid], | ||
right, | ||
out_dtype, | ||
) | ||
|
||
return ib.get() | ||
|
||
return te.extern( | ||
values.shape, | ||
[sorted_sequence, values], | ||
lambda ins, outs: ir(ins[0], ins[1], outs[0]), | ||
name="searchsorted", | ||
dtype=out_dtype, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, just curious if there is any convention on the dtype of indices, there is a lot of index code with dyn gather I believe has all the indices in Int(64). Int(64) might be a better default.
The other attributes in this file have
NullValue<DataType>()
as the default value which is interesting.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, other ops have
NullValue<DataType>()
as the default here, but if we look at the python definition at https://github.com/apache/tvm/blob/main/python/tvm/relay/op/algorithm.py#L47, they say the default is int32. So I thought we should make that explicit inattrs/algorithm.h
as well.