Skip to content

Commit

Permalink
replace "side" attribute with boolean "right"
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Oct 12, 2021
1 parent 5eb0c15 commit 5fc1bbb
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 52 deletions.
9 changes: 6 additions & 3 deletions include/tvm/relay/attrs/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,15 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
};

struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
std::string side;
bool right;
DataType dtype;

TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
TVM_ATTR_FIELD(side).set_default("left").describe(
"Controls which index is returned if a value lands exactly on one of sorted values.");
TVM_ATTR_FIELD(right).set_default(false).describe(
"Controls which index is returned if a value lands exactly on one of sorted values. If "
" false, the index of the first suitable location found is given. If true, return the "
"last such index. If there is no suitable index, return either 0 or N (where N is the "
"size of the innermost dimension).");
TVM_ATTR_FIELD(dtype)
.set_default(DataType::Int(32))
.describe("Data type of the output indices.");
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2776,13 +2776,12 @@ def all_any_common(self, op, inputs, input_types):

def searchsorted_common(self, sorted_sequence, values, out_int32, right):
dtype = "int32" if out_int32 else "int64"
side = "right" if right else "left"
values_shape = _infer_shape(values)

if len(values_shape) == 0:
values = _op.expand_dims(values, 0)

out = _op.searchsorted(sorted_sequence, values, side=side, dtype=dtype)
out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype)

if len(values_shape) == 0:
return _op.squeeze(out)
Expand Down
13 changes: 7 additions & 6 deletions python/tvm/relay/op/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
return out


def searchsorted(sorted_sequence, values, side="left", dtype="int32"):
def searchsorted(sorted_sequence, values, right=False, dtype="int32"):
"""Find indices where elements should be inserted to maintain order.
If `sorted_sequence` is N-dimensional, the innermost dimension of
`values` are searched in the corresponding dimension of `sorted_sequence`.
Expand All @@ -133,10 +133,11 @@ def searchsorted(sorted_sequence, values, side="left", dtype="int32"):
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
and `values` must be the same, and outer N-1 axes must have the same size.
side : string, optional
It can be `left` or `right`. If `left`, gets the lower bound index for each value
in `values` on the corresponding innermost dimension of the `sorted_sequence`.
If `right`, gets the upper bound index instead.
right : bool, optional
Controls which index is returned if a value lands exactly on one of sorted values. If
False, the index of the first suitable location found is given. If true, return the
last such index. If there is no suitable index, return either 0 or N (where N is the
size of the innermost dimension).
dtype : string, optional
The data type of the output indices.
Expand All @@ -147,4 +148,4 @@ def searchsorted(sorted_sequence, values, side="left", dtype="int32"):
Tensor with same shape as values, representing the indices of
elements of `values` if they are inserted in `sorted_sequence`.
"""
return _make.searchsorted(sorted_sequence, values, side, dtype)
return _make.searchsorted(sorted_sequence, values, right, dtype)
4 changes: 2 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,9 +1007,9 @@ def wrap_compute_searchsorted(topi_compute):
"""Wrap searchsorted compute"""

def _compute_searchsorted(attrs, inputs, out_type):
side = attrs.side
right = attrs.right
dtype = attrs.dtype
return [topi_compute(inputs[0], inputs[1], side, dtype)]
return [topi_compute(inputs[0], inputs[1], right, dtype)]

return _compute_searchsorted

Expand Down
13 changes: 7 additions & 6 deletions python/tvm/topi/cuda/searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ..searchsorted import binary_search


def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"):
def searchsorted(sorted_sequence, values, right, out_dtype="int64"):
"""Find indices where elements should be inserted to maintain order.
If `sorted_sequence` is N-dimensional, the innermost dimension of
`values` are searched in the corresponding dimension of `sorted_sequence`.
Expand All @@ -38,10 +38,11 @@ def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"):
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
and `values` must be the same, and outer N-1 axes must have the same size.
side : string, optional
It can be `left` or `right`. If `left`, gets the lower bound index for each value
in `values` on the corresponding innermost dimension of the `sorted_sequence`.
If `right`, gets the upper bound index instead.
right : bool, optional
Controls which index is returned if a value lands exactly on one of sorted values. If
False, the index of the first suitable location found is given. If true, return the
last such index. If there is no suitable index, return either 0 or N (where N is the
size of the innermost dimension).
dtype : string, optional
The data type of the output indices.
Expand Down Expand Up @@ -88,7 +89,7 @@ def ir(sorted_sequence, values, indices):
sorted_sequence,
values,
indices,
side,
right,
out_dtype,
)

Expand Down
21 changes: 11 additions & 10 deletions python/tvm/topi/searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


def binary_search(
ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype
ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, right, out_dtype
):
"""Common IR generator for CPU and GPU searchsorted."""
lo = ib.allocate(out_dtype, (1,), name="lo", scope="local")
Expand All @@ -35,9 +35,9 @@ def binary_search(

# Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu
def condition(current_val, target_val):
if side == "left":
return current_val < target_val
return current_val <= target_val
if right:
return current_val <= target_val
return current_val < target_val

with ib.while_loop(lo[0] < hi[0]):
mid = lo[0] + (hi[0] - lo[0] >> 1)
Expand All @@ -49,7 +49,7 @@ def condition(current_val, target_val):
out_indices[index] = lo[0]


def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"):
def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
"""Find indices where elements should be inserted to maintain order.
If `sorted_sequence` is N-dimensional, the innermost dimension of
`values` are searched in the corresponding dimension of `sorted_sequence`.
Expand All @@ -65,10 +65,11 @@ def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"):
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
and `values` must be the same, and outer N-1 axes must have the same size.
side : string, optional
It can be `left` or `right`. If `left`, gets the lower bound index for each value
in `values` on the corresponding innermost dimension of the `sorted_sequence`.
If `right`, gets the upper bound index instead.
right : bool, optional
Controls which index is returned if a value lands exactly on one of sorted values. If
False, the index of the first suitable location found is given. If true, return the
last such index. If there is no suitable index, return either 0 or N (where N is the
size of the innermost dimension).
dtype : string, optional
The data type of the output indices.
Expand Down Expand Up @@ -106,7 +107,7 @@ def ir(sorted_sequence, values, indices):
sorted_sequence,
values,
indices,
side,
right,
out_dtype,
)

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/topi/testing/searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import numpy as np


def searchsorted_ref(sorted_sequence, values, side, out_dtype):
def searchsorted_ref(sorted_sequence, values, right, out_dtype):
"""Run Numpy searchsorted on 1-D or N-D sorted_sequence."""
side = "right" if right else "left"
if len(sorted_sequence.shape) == 1 and len(values.shape) > 1:
sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1))
else:
Expand Down
6 changes: 2 additions & 4 deletions src/relay/op/algorithm/searchsorted.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attr
ICHECK(sorted_sequence) << "Expects TensorType in the first input";
ICHECK(values) << "Expects TensorType in the second input";
ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one";
ICHECK(param->side == "left" || param->side == "right")
<< "`side` parameter must be either `left` or `right`";

if (sorted_sequence->shape.size() > 1) {
ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size())
Expand All @@ -60,11 +58,11 @@ bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attr
return true;
}

Expr MakeSearchSorted(Expr sorted_sequence, Expr values, String side, DataType dtype) {
Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) {
auto attrs = make_object<SearchSortedAttrs>();
static const Op& op = Op::Get("searchsorted");
attrs->dtype = dtype;
attrs->side = side;
attrs->right = right;
return Call(op, {sorted_sequence, values}, Attrs(attrs), {});
}

Expand Down
10 changes: 5 additions & 5 deletions tests/python/relay/test_op_level6.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,25 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):

@tvm.testing.uses_gpu
def test_searchsorted():
def verify_searchsorted(side, dtype):
def verify_searchsorted(right, dtype):
shape = (8, 9, 10)
values_shape = shape[:-1] + (10,)
sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32"))
values = relay.var("sorted_sequence", relay.TensorType(values_shape, "float32"))
out = relay.searchsorted(sorted_sequence, values, side, dtype)
out = relay.searchsorted(sorted_sequence, values, right, dtype)
func = relay.Function([sorted_sequence, values], out)
sorted_sequence_np = np.sort(np.random.randn(*shape).astype("float32"), axis=-1)
values_np = np.random.randn(*values_shape).astype("float32")
np_indices = searchsorted_ref(sorted_sequence_np, values_np, side, dtype)
np_indices = searchsorted_ref(sorted_sequence_np, values_np, right, dtype)

for target, dev in tvm.testing.enabled_targets():
op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
sorted_sequence_np, values_np
)
np.testing.assert_equal(op_res.numpy(), np_indices)

verify_searchsorted("left", "int32")
verify_searchsorted("right", "int64")
verify_searchsorted(False, "int32")
verify_searchsorted(True, "int64")


if __name__ == "__main__":
Expand Down
26 changes: 13 additions & 13 deletions tests/python/topi/python/test_topi_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def get_implementations():

@tvm.testing.parametrize_targets
def test_searchsorted(dev, target):
def verify_with_input(sorted_sequence_np, values_np, side):
def verify_with_input(sorted_sequence_np, values_np, right):
sorted_sequence = te.placeholder(sorted_sequence_np.shape, dtype="float32")
values = te.placeholder(values_np.shape, dtype="float32")
out_dtype = "int32"
implementations = get_implementations()
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)

with tvm.target.Target(target):
indices = fcompute(sorted_sequence, values, side, out_dtype)
indices = fcompute(sorted_sequence, values, right, out_dtype)
s = fschedule([indices])

func = tvm.build(s, [sorted_sequence, values, indices], target=target)
Expand All @@ -64,10 +64,10 @@ def verify_with_input(sorted_sequence_np, values_np, side):
b = tvm.nd.array(values_np, dev)
c = tvm.nd.array(np.zeros(values_np.shape, dtype=indices.dtype), dev)
func(a, b, c)
ref = searchsorted_ref(sorted_sequence_np, values_np, side, out_dtype)
ref = searchsorted_ref(sorted_sequence_np, values_np, right, out_dtype)
np.testing.assert_equal(c.numpy(), ref)

def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False):
def verify(sequence_len, num_search, outer_axes, right, sorted_sequence_1d=False):
if sorted_sequence_1d:
sorted_sequence_shape = (sequence_len,)
else:
Expand All @@ -77,17 +77,17 @@ def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False)
verify_with_input(
np.sort(np.random.randn(*sorted_sequence_shape).astype("float32"), axis=-1),
np.random.randn(*values_shape).astype("float32"),
side,
right,
)

verify(1024, 1000, (10, 5, 3), "left")
verify(999, 2000, (10, 5, 3), "right")
verify(1000, 1000, (), "left")
verify(2001, 100, (500,), "right")
verify(2001, 100, (500,), "left", sorted_sequence_1d=True)
verify(1024, 1000, (10, 5, 3), False)
verify(999, 2000, (10, 5, 3), True)
verify(1000, 1000, (), False)
verify(2001, 100, (500,), True)
verify(2001, 100, (500,), False, sorted_sequence_1d=True)

# Check edge cases
for side in ["left", "right"]:
for right in [True, False]:
sorted_sequence = np.array([1, 2, 3, 4, 5], dtype="float32")
verify_with_input(sorted_sequence, np.array([6], dtype="float32"), side)
verify_with_input(sorted_sequence, np.array([0], dtype="float32"), side)
verify_with_input(sorted_sequence, np.array([6], dtype="float32"), right)
verify_with_input(sorted_sequence, np.array([0], dtype="float32"), right)

0 comments on commit 5fc1bbb

Please sign in to comment.