Add sort op to relay
Matthew Brookhart committed Dec 8, 2020
1 parent 0095b21 commit 49f0a48
Showing 13 changed files with 436 additions and 5 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_algorithm.py
@@ -26,6 +26,10 @@
from .op import OpPattern, register_pattern
from .op import register_strategy

# sort
register_strategy("sort", strategy.sort_strategy)
register_pattern("sort", OpPattern.OPAQUE)

# argsort
register_strategy("argsort", strategy.argsort_strategy)
register_pattern("argsort", OpPattern.OPAQUE)
22 changes: 22 additions & 0 deletions python/tvm/relay/op/algorithm.py
@@ -22,6 +22,28 @@
from ..expr import TupleWrapper, Expr, Constant


def sort(data, axis=-1, is_ascend=1):
    """Performs sorting along the given axis and returns data in sorted order.
    Parameters
    ----------
    data : relay.Expr
        The input data tensor.
    axis : int, optional
        Axis along which to sort the input tensor.
    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.
    Returns
    -------
    out : relay.Expr
        Tensor with the same shape as data, sorted along the given axis.
    """
    return _make.sort(data, axis, is_ascend)


def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
    """Performs sorting along the given axis and returns an array of indices
    having the same shape as the input array that index data in sorted order.
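For reference, a minimal usage sketch of the new `relay.sort` wrapper above. This is not part of the commit; the graph executor and `llvm` target are illustrative choices, and executor/NDArray API details vary slightly across TVM versions:

```python
import numpy as np
import tvm
from tvm import relay

# Build a one-op Relay function that sorts along the last axis.
x = relay.var("x", shape=(4, 8), dtype="float32")
func = relay.Function([x], relay.sort(x, axis=-1, is_ascend=True))
mod = tvm.IRModule.from_expr(func)

# Compile for CPU and run through the graph executor.
data = np.random.rand(4, 8).astype("float32")
out = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(data)
np.testing.assert_allclose(out.asnumpy(), np.sort(data, axis=-1), rtol=1e-5)
```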
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/_algorithm.py
@@ -20,6 +20,7 @@

from tvm.te.hybrid import script
from tvm.runtime import convert
from tvm import topi

from .. import strategy
from .. import op as _reg
20 changes: 20 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -783,6 +783,26 @@ def scatter_nd_cuda(attrs, inputs, out_type, target):
name="scatter_nd.cuda",
plevel=10,
)


@sort_strategy.register(["cuda", "gpu"])
def sort_strategy_cuda(attrs, inputs, out_type, target):
"""sort cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sort(topi.cuda.sort),
wrap_topi_schedule(topi.cuda.schedule_sort),
name="sort.cuda",
)
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.sort", allow_missing=True
):
strategy.add_implementation(
wrap_compute_sort(topi.cuda.sort_thrust),
wrap_topi_schedule(topi.cuda.schedule_sort),
name="sort_thrust.cuda",
plevel=15,
)
return strategy


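A side note on the CUDA strategy above: the thrust-backed implementation is only added (at a higher `plevel`) when the `tvm.contrib.thrust.sort` packed function is registered. A small sketch mirroring that check from Python, assuming a TVM build with thrust enabled (`USE_THRUST=ON` in the cmake config):

```python
import tvm

# The higher-plevel thrust implementation is only registered when this
# packed function exists in the runtime, i.e. TVM was built with thrust.
has_thrust = tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
print("thrust sort available:", has_thrust)
```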
24 changes: 24 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -749,6 +749,30 @@ def schedule_sparse_transpose(attrs, outs, target):
        return topi.generic.schedule_sparse_transpose(outs)


# sort
def wrap_compute_sort(topi_compute):
    """Wrap sort topi compute"""

    def _compute_sort(attrs, inputs, _):
        axis = get_const_int(attrs.axis)
        is_ascend = bool(get_const_int(attrs.is_ascend))
        return [topi_compute(inputs[0], axis=axis, is_ascend=is_ascend)]

    return _compute_sort


@override_native_generic_func("sort_strategy")
def sort_strategy(attrs, inputs, out_type, target):
    """sort generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_sort(topi.sort),
        wrap_topi_schedule(topi.generic.schedule_sort),
        name="sort.generic",
    )
    return strategy


# argsort
def wrap_compute_argsort(topi_compute):
    """Wrap argsort topi compute"""
100 changes: 100 additions & 0 deletions python/tvm/topi/cuda/sort.py
@@ -315,6 +315,89 @@ def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32")
    return out[1]


def sort(data, axis=-1, is_ascend=1):
    """Performs sorting along the given axis and returns data in sorted order.
    Parameters
    ----------
    data: tvm.te.Tensor
        The input array.
    axis : int, optional
        Axis along which to sort the input tensor.
    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.
    Returns
    -------
    out : tvm.te.Tensor
        Sorted tensor with the same shape as data.
    """
    dtype = "float32"
    value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
    indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
    out = te.extern(
        [data.shape, data.shape],
        [data],
        lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
        out_buffers=[value_buf, indices_buf],
        name="sort_gpu",
        tag="sort_gpu",
    )[0]
    return out


def sort_thrust(data, axis=-1, is_ascend=1):
    """Performs sorting along the given axis using thrust and returns data
    in sorted order.
    Parameters
    ----------
    data: tvm.te.Tensor
        The input array.
    axis : int, optional
        Axis along which to sort the input tensor.
    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.
    Returns
    -------
    out : tvm.te.Tensor
        Sorted tensor with the same shape as data.
    """
    dtype = "float32"

    ndim = len(data.shape)
    axis = ndim + axis if axis < 0 else axis

    if axis != ndim - 1:
        # Prepare for sorting along axis -1.
        axes = swap(list(range(ndim)), axis)
        data = transpose(data, axes)

    value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
    indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
    out = te.extern(
        [data.shape, data.shape],
        [data],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend
        ),
        out_buffers=[value_buf, indices_buf],
        name="sort_gpu",
        tag="sort_gpu",
    )[0]

    if axis != ndim - 1:
        axes = swap(list(range(ndim)), axis)
        out = transpose(out, axes)
    return out


def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
    """Performs sorting along the given axis and returns an array of indices
    having the same shape as the input array that index data in sorted order.
@@ -407,6 +490,23 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
    return out


def schedule_sort(outs):
    """Schedule for sort operator.
    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of sort
        in the format of an array of tensors.
    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    return _schedule_sort(outs)


def schedule_argsort(outs):
    """Schedule for argsort operator.
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/sort.py
@@ -20,6 +20,23 @@
from .default import default_schedule as _default_schedule


def schedule_sort(outs):
    """Schedule for sort operator.
    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of sort
        in the format of an array of tensors.
    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


def schedule_argsort(outs):
    """Schedule for argsort operator.
44 changes: 43 additions & 1 deletion python/tvm/topi/sort.py
@@ -14,13 +14,55 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments, unused-argument
"""Argsort operator"""
import tvm
from tvm import te
from .utils import get_const_tuple


def sort(data, axis=-1, is_ascend=1):
    """Performs sorting along the given axis and returns an array
    in sorted order.
    Parameters
    ----------
    data : tvm.te.Tensor
        The input tensor.
    axis : int, optional
        Axis along which to sort the input tensor.
        The last axis is used by default.
    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.
    Returns
    -------
    out : tvm.te.Tensor
        Sorted output tensor with the same shape as data.
    """
    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
    out = te.extern(
        data.shape,
        [data],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.sort.sort", ins[0], outs[0], axis, is_ascend
        ),
        dtype=data.dtype,
        in_buffers=[data_buf],
        out_buffers=out_buf,
        name="sort_cpu",
        tag="sort_cpu",
    )
    return out


def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
    """Performs sorting along the given axis and returns an array
    of indices having the same shape as an input array that index
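For the CPU `topi.sort` added above, a minimal sketch of lowering it through the `tvm.contrib.sort.sort` packed function. This is not part of the diff; the default schedule, `llvm` target, and 1-D input are illustrative assumptions:

```python
import numpy as np
import tvm
from tvm import te, topi

# Declare a 1-D input, attach the extern sort compute, and build for CPU.
data = te.placeholder((8,), name="data", dtype="float32")
out = topi.sort(data, axis=-1, is_ascend=1)
s = te.create_schedule(out.op)
f = tvm.build(s, [data, out], target="llvm")

# Run once and compare against numpy's sort along the last axis.
a = tvm.nd.array(np.random.rand(8).astype("float32"))
b = tvm.nd.array(np.zeros(8, dtype="float32"))
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.sort(a.asnumpy()), rtol=1e-5)
```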
65 changes: 65 additions & 0 deletions src/relay/op/algorithm/sort.cc
@@ -0,0 +1,65 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file sort.cc
 * \brief Sort operators
 */
#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/op.h>

namespace tvm {
namespace relay {

bool SortRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, result]
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Sort: expect input type to be TensorType but get " << types[0];
    return false;
  }
  reporter->Assign(types[1], TensorType(data->shape, data->dtype));
  return true;
}

Expr MakeSort(Expr data, int axis, bool is_ascend) {
  auto attrs = make_object<ArgsortAttrs>();
  attrs->axis = axis;
  attrs->is_ascend = is_ascend;
  static const Op& op = Op::Get("sort");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.sort").set_body_typed(MakeSort);

RELAY_REGISTER_OP("sort")
.describe(R"doc(Returns the indices that would sort an
input array along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<ArgsortAttrs>()
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("Sort", SortRel);

} // namespace relay
} // namespace tvm
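A quick sanity check from Python (not part of this commit) that the registration above is visible and that the reused `ArgsortAttrs` carry the axis and ordering through `MakeSort`:

```python
import tvm
from tvm import relay

# The op is looked up by the same name used in RELAY_REGISTER_OP("sort").
op = relay.op.get("sort")
print(op.name, op.num_inputs)  # -> sort 1

# Attrs on a constructed call come from MakeSort via the shared ArgsortAttrs.
call = relay.sort(relay.var("x", shape=(3, 4), dtype="float32"), axis=0, is_ascend=False)
print(call.attrs.axis, call.attrs.is_ascend)
```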