[Contrib] Support fp16 input in cpu sort (apache#8672)
masahi authored and mehrdadh committed Aug 11, 2021
1 parent 8da0114 commit 84f7bbf
Showing 3 changed files with 83 additions and 32 deletions.
71 changes: 60 additions & 11 deletions src/runtime/contrib/sort/sort.cc
@@ -21,6 +21,7 @@
  * \file Use standard C library call.
  */
 
+#include <builtin_fp16.h>
 #include <dlpack/dlpack.h>
 #include <tvm/runtime/registry.h>
 
@@ -42,6 +43,24 @@ bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_
   return lhs.second > rhs.second;
 }
 
+struct float16 {
+  uint16_t bits;
+  float to_float() const {
+    return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
+  }
+};
+
+template <>
+bool CompareAscend(const std::pair<int64_t, float16>& lhs, const std::pair<int64_t, float16>& rhs) {
+  return lhs.second.to_float() < rhs.second.to_float();
+}
+
+template <>
+bool CompareDescend(const std::pair<int64_t, float16>& lhs,
+                    const std::pair<int64_t, float16>& rhs) {
+  return lhs.second.to_float() > rhs.second.to_float();
+}
+
 // Argsort implemented C library sort for nms.
 // Return indices of sorted tensor.
 // By default, the last axis will be used to sort.
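Note on the float16 struct above: it stores the raw IEEE 754 binary16 bits and widens them to float via __extendXfYf2__ from the vendored compiler-rt header builtin_fp16.h (the 10 and 23 correspond to the significand widths of binary16 and binary32). The widening matters because half-precision bit patterns are sign-magnitude: treated as unsigned integers, every negative value would compare greater than every positive one. For intuition only, here is a minimal hand-written sketch of the widening; it is not the compiler-rt implementation:

#include <cstdint>
#include <cstring>

// Widen an IEEE 754 binary16 bit pattern to float, covering normals,
// subnormals, signed zeros, infinities, and NaNs.
float HalfToFloat(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000) << 16;
  uint32_t exp = (h >> 10) & 0x1F;
  uint32_t frac = h & 0x3FF;
  uint32_t bits;
  if (exp == 0x1F) {  // all-ones exponent: infinity or NaN
    bits = sign | 0x7F800000u | (frac << 13);
  } else if (exp != 0) {  // normal: rebias exponent from 15 to 127
    bits = sign | ((exp + 112) << 23) | (frac << 13);
  } else if (frac == 0) {  // signed zero
    bits = sign;
  } else {  // subnormal: normalize the fraction, adjusting the exponent
    exp = 113;
    while ((frac & 0x400) == 0) {
      frac <<= 1;
      --exp;
    }
    bits = sign | (exp << 23) | ((frac & 0x3FF) << 13);
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}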
@@ -125,7 +144,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV
 });
 
 template <typename DataType, typename OutType>
-void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) {
+void sort_impl(
+    DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
+    std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) {
   auto data_ptr = static_cast<DataType*>(input->data);
   auto out_ptr = static_cast<OutType*>(output->data);
   std::vector<std::pair<int64_t, DataType>> sorter;
@@ -153,27 +174,29 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
     } else {
       std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
     }
-    if (is_argsort) {
-      for (int64_t k = 0; k < input->shape[axis]; ++k) {
-        out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
-      }
-    } else {
-      for (int64_t k = 0; k < input->shape[axis]; ++k) {
-        out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].second);
-      }
+    for (int64_t k = 0; k < input->shape[axis]; ++k) {
+      epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]);
     }
   }
 }
 
 template <typename DataType, typename OutType>
 void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
-  return sort_impl<DataType, OutType>(input, output, axis, is_ascend, true);
+  return sort_impl<DataType, OutType>(
+      input, output, axis, is_ascend,
+      [](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
+        out_ptr[index] = static_cast<OutType>(sort_pair.first);
+      });
 }
 
 template <typename DataType>
 void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
-  return sort_impl<DataType, DataType>(input, output, axis, is_ascend, false);
+  return sort_impl<DataType, DataType>(
+      input, output, axis, is_ascend,
+      [](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
+        out_ptr[index] = sort_pair.second;
+      });
 }
 
 // Argsort implemented C library sort.
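Refactoring note: sort_impl previously branched on an is_argsort flag inside the innermost loop; it now takes an epilogue callback that decides what gets written out, so argsort (indices) and sort (values) share one sorting core. A self-contained sketch of the same pattern, with illustrative names (SortCore is not a TVM symbol):

#include <algorithm>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

template <typename DataType, typename OutType>
void SortCore(const std::vector<DataType>& data, OutType* out,
              std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) {
  std::vector<std::pair<int64_t, DataType>> sorter;  // (original index, value)
  for (int64_t i = 0; i < static_cast<int64_t>(data.size()); ++i) {
    sorter.emplace_back(i, data[i]);
  }
  std::stable_sort(sorter.begin(), sorter.end(),
                   [](const std::pair<int64_t, DataType>& l,
                      const std::pair<int64_t, DataType>& r) { return l.second < r.second; });
  for (size_t k = 0; k < sorter.size(); ++k) {
    epilogue(out, k, sorter[k]);  // the caller chooses: write the index or the value
  }
}

int main() {
  std::vector<float> data = {3.f, 1.f, 2.f};
  std::vector<int64_t> indices(3);
  std::vector<float> values(3);
  // argsort-style epilogue writes the original index ...
  SortCore<float, int64_t>(data, indices.data(),
      [](int64_t* o, size_t i, const std::pair<int64_t, float>& p) { o[i] = p.first; });
  // ... sort-style epilogue writes the value itself.
  SortCore<float, float>(data, values.data(),
      [](float* o, size_t i, const std::pair<int64_t, float>& p) { o[i] = p.second; });
  return 0;  // indices = {1, 2, 0}, values = {1, 2, 3}
}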
@@ -254,6 +277,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
+  } else if (data_dtype == "float16") {
+    if (out_dtype == "int32") {
+      argsort<float16, int32_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      argsort<float16, int64_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      argsort<float16, float>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      argsort<float16, double>(input, output, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
@@ -295,6 +330,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal
     sort<int32_t>(input, output, axis, is_ascend);
   } else if (data_dtype == "int64") {
     sort<int64_t>(input, output, axis, is_ascend);
+  } else if (data_dtype == "float16") {
+    sort<float16>(input, output, axis, is_ascend);
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
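Note that sort dispatches on the input dtype alone: sort<float16> instantiates sort_impl with DataType = OutType = float16, so the raw 16-bit payloads are copied through unchanged and the widening to float happens only inside the comparators. A float16 out_dtype for argsort/topk remains unsupported and falls through to the LOG(FATAL) branch.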
@@ -432,6 +469,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
+  } else if (data_dtype == "float16") {
+    if (out_dtype == "int32") {
+      topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
42 changes: 22 additions & 20 deletions tests/python/relay/test_op_level6.py
@@ -16,24 +16,23 @@
 # under the License.
 """ Support level6 operator test cases.
 """
+import pytest
 import numpy as np
 import tvm
-from tvm import te
 from tvm import relay
 import tvm.testing
 
 
 @tvm.testing.uses_gpu
 def test_sort():
-    def verify_sort(shape, axis, is_ascend, is_dyn=False):
-
+    def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
         if is_dyn:
-            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
+            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
         else:
-            x = relay.var("x", relay.TensorType(shape, "float32"))
+            x = relay.var("x", relay.TensorType(shape, in_dtype))
         z = relay.sort(x, axis=axis, is_ascend=is_ascend)
         func = relay.Function([x], z)
-        x_data = np.random.uniform(size=shape).astype("float32")
+        x_data = np.random.uniform(size=shape).astype(in_dtype)
         if is_ascend:
             ref_res = np.sort(x_data, axis=axis)
         else:
@@ -56,18 +55,19 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False):
         verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn)
         verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn)
         verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn)
+        verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn, in_dtype="float16")
 
 
 @tvm.testing.uses_gpu
 def test_argsort():
-    def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
+    def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float32"):
         if is_dyn:
-            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
+            x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
         else:
-            x = relay.var("x", relay.TensorType(shape, "float32"))
+            x = relay.var("x", relay.TensorType(shape, in_dtype))
         z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
         func = relay.Function([x], z)
-        x_data = np.random.uniform(size=shape).astype("float32")
+        x_data = np.random.uniform(size=shape).astype(in_dtype)
         if is_ascend:
             ref_res = np.argsort(x_data, axis=axis, kind="stable")
         else:
@@ -93,31 +93,34 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
             verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
             verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
             verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+            verify_argsort(
+                (1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn, in_dtype="float16"
+            )
 
 
 @tvm.testing.uses_gpu
 def test_topk():
-    def verify_topk(k, axis, ret_type, is_ascend, dtype):
+    def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):
         shape = (20, 100)
-        x = relay.var("x", relay.TensorType(shape, "float32"))
+        x = relay.var("x", relay.TensorType(shape, in_dtype))
         out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
         if isinstance(out, relay.expr.TupleWrapper):
             out = out.astuple()
         func = relay.Function([x], out)
-        np_data = np.random.uniform(size=shape).astype("float32")
+        np_data = np.random.uniform(size=shape).astype(in_dtype)
         if is_ascend:
-            np_indices = np.argsort(np_data, axis=axis)
+            np_indices = np.argsort(np_data, axis=axis, kind="stable")
         else:
-            np_indices = np.argsort(-np_data, axis=axis)
+            np_indices = np.argsort(-np_data, axis=axis, kind="stable")
         kk = k if k >= 1 else shape[axis]
         if axis == 0:
             np_indices = np_indices[:kk, :]
-            np_values = np.zeros(np_indices.shape).astype("float32")
+            np_values = np.zeros(np_indices.shape).astype(in_dtype)
             for i in range(shape[1]):
                 np_values[:, i] = np_data[np_indices[:, i], i]
         else:
             np_indices = np_indices[:, :kk]
-            np_values = np.zeros(np_indices.shape).astype("float32")
+            np_values = np.zeros(np_indices.shape).astype(in_dtype)
             for i in range(shape[0]):
                 np_values[i, :] = np_data[i, np_indices[i, :]]
         np_indices = np_indices.astype(dtype)
@@ -140,9 +143,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
             for ret_type in ["both", "values", "indices"]:
                 verify_topk(k, axis, ret_type, True, "int64")
                 verify_topk(k, axis, ret_type, False, "float32")
+                verify_topk(k, axis, ret_type, False, "int64", "float16")
 
 
 if __name__ == "__main__":
-    test_sort()
-    test_argsort()
-    test_topk()
+    pytest.main([__file__])
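The reference argsort in verify_topk now passes kind="stable": np.argsort defaults to an unstable quicksort, while the CPU kernel above uses std::stable_sort, so the two must break ties between equal values the same way for the index comparison to pass. Ties are far more likely with float16 inputs, which carry only 11 bits of significand precision.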
2 changes: 1 addition & 1 deletion web/Makefile
@@ -18,7 +18,7 @@
 TVM_ROOT=$(shell cd ..; pwd)
 
 INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\
-    -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include
+    -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -I$(TVM_ROOT)/3rdparty/compiler-rt
 
 .PHONY: clean all rmtypedep preparetest
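This new include path points the web runtime build at TVM's vendored compiler-rt headers, so the #include <builtin_fp16.h> added in sort.cc resolves there as well.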

