Skip to content

Commit

Permalink
Support dynamic shape searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Oct 21, 2021
1 parent d11bdcd commit 888a0fe
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/tvm/relay/op/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,28 @@ def topk_shape_func(attrs, inputs, _):
ret = [indices_out]

return ret


@script
def _searchsorted_shape(sorted_sequence_shape, values_shape):
out_shape = output_tensor((values_shape.shape[0],), "int64")
if sorted_sequence_shape.shape[0] > 1:
assert (
sorted_sequence_shape.shape[0] == values_shape.shape[0]
), "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is not 1-D."
for i in range(values_shape.shape[0]):
if sorted_sequence_shape.shape[0] > 1 and i < values_shape.shape[0] - 1:
assert (
sorted_sequence_shape[i] == values_shape[i]
), "`sorted_sequence and `values` do not have the same shape along outer axes."

out_shape[i] = values_shape[i]
return out_shape


@_reg.register_shape_func("searchsorted", False)
def searchsorted_shape_func(attrs, inputs, _):
"""
Shape func for searchsorted operator.
"""
return [_searchsorted_shape(inputs[0], inputs[1])]
31 changes: 31 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm import relay, te
from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type
from tvm.topi.testing import searchsorted_ref

from utils import ref_funcs
from utils.assert_diagnostic import DiagnosticTesting
Expand Down Expand Up @@ -2086,5 +2087,35 @@ def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, ax
verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0)


@tvm.testing.uses_gpu
def test_searchsorted():
def verify_searchsorted(
sorted_sequence_shape, values_shape, sorted_sequence_shape_np, values_shape_np
):
x = relay.var("x", relay.TensorType(sorted_sequence_shape, "float32"))
y = relay.var("y", relay.TensorType(values_shape, "float32"))
z = relay.searchsorted(x, y)

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)

x_np = np.sort(np.random.uniform(size=sorted_sequence_shape_np).astype("float32"), axis=-1)
y_np = np.random.uniform(size=values_shape_np).astype("float32")

ref_res = searchsorted_ref(x_np, y_np, False, "int32")
check_result([x_np, y_np], mod, [ref_res])

for shape_np, values_shape_np in zip([(8, 9, 10), (10,), (11,)], [(8, 9, 20), (5,), (8, 9, 7)]):
sorted_sequence_shape = (relay.Any(),) * len(shape_np)
values_shape = (relay.Any(),) * len(values_shape_np)

verify_searchsorted(
sorted_sequence_shape,
values_shape,
shape_np,
values_shape_np,
)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 888a0fe

Please sign in to comment.