diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 19162a1083955..dd1a65288955e 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -84,3 +84,28 @@ def topk_shape_func(attrs, inputs, _): ret = [indices_out] return ret + + +@script +def _searchsorted_shape(sorted_sequence_shape, values_shape): + out_shape = output_tensor((values_shape.shape[0],), "int64") + if sorted_sequence_shape.shape[0] > 1: + assert ( + sorted_sequence_shape.shape[0] == values_shape.shape[0] + ), "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is not 1-D." + for i in range(values_shape.shape[0]): + if sorted_sequence_shape.shape[0] > 1 and i < values_shape.shape[0] - 1: + assert ( + sorted_sequence_shape[i] == values_shape[i] + ), "`sorted_sequence and `values` do not have the same shape along outer axes." + + out_shape[i] = values_shape[i] + return out_shape + + +@_reg.register_shape_func("searchsorted", False) +def searchsorted_shape_func(attrs, inputs, _): + """ + Shape func for searchsorted operator. + """ + return [_searchsorted_shape(inputs[0], inputs[1])] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8788faf45866e..f42f7ad7ca697 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -23,6 +23,7 @@ from tvm import relay, te from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +from tvm.topi.testing import searchsorted_ref from utils import ref_funcs from utils.assert_diagnostic import DiagnosticTesting @@ -2086,5 +2087,35 @@ def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, ax verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0) +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted( + sorted_sequence_shape, values_shape, sorted_sequence_shape_np, values_shape_np + ): + x = relay.var("x", relay.TensorType(sorted_sequence_shape, "float32")) + y = relay.var("y", relay.TensorType(values_shape, "float32")) + z = relay.searchsorted(x, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + x_np = np.sort(np.random.uniform(size=sorted_sequence_shape_np).astype("float32"), axis=-1) + y_np = np.random.uniform(size=values_shape_np).astype("float32") + + ref_res = searchsorted_ref(x_np, y_np, False, "int32") + check_result([x_np, y_np], mod, [ref_res]) + + for shape_np, values_shape_np in zip([(8, 9, 10), (10,), (11,)], [(8, 9, 20), (5,), (8, 9, 7)]): + sorted_sequence_shape = (relay.Any(),) * len(shape_np) + values_shape = (relay.Any(),) * len(values_shape_np) + + verify_searchsorted( + sorted_sequence_shape, + values_shape, + shape_np, + values_shape_np, + ) + + if __name__ == "__main__": pytest.main([__file__])