[Contrib] Support fp16 input in cpu sort #8672

Merged
5 commits merged on Aug 7, 2021
Changes from 3 commits
71 changes: 60 additions & 11 deletions src/runtime/contrib/sort/sort.cc
@@ -21,6 +21,7 @@
* \file Use standard C library call.
*/

#include <builtin_fp16.h>
#include <dlpack/dlpack.h>
#include <tvm/runtime/registry.h>

@@ -42,6 +43,24 @@ bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_
return lhs.second > rhs.second;
}

struct float16 {
uint16_t bits;
float to_float() const {
return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
}
};

template <>
bool CompareAscend(const std::pair<int64_t, float16>& lhs, const std::pair<int64_t, float16>& rhs) {
return lhs.second.to_float() < rhs.second.to_float();
}

template <>
bool CompareDescend(const std::pair<int64_t, float16>& lhs,
const std::pair<int64_t, float16>& rhs) {
return lhs.second.to_float() > rhs.second.to_float();
}

// Argsort implemented C library sort for nms.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
@@ -125,7 +144,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV
});

template <typename DataType, typename OutType>
void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) {
void sort_impl(
DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) {
auto data_ptr = static_cast<DataType*>(input->data);
auto out_ptr = static_cast<OutType*>(output->data);
std::vector<std::pair<int64_t, DataType>> sorter;
@@ -153,27 +174,29 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
} else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
}
if (is_argsort) {
for (int64_t k = 0; k < input->shape[axis]; ++k) {
out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
}
} else {
for (int64_t k = 0; k < input->shape[axis]; ++k) {
out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].second);
}
for (int64_t k = 0; k < input->shape[axis]; ++k) {
epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]);
}
}
}
}

template <typename DataType, typename OutType>
void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
return sort_impl<DataType, OutType>(input, output, axis, is_ascend, true);
return sort_impl<DataType, OutType>(
input, output, axis, is_ascend,
[](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
out_ptr[index] = static_cast<OutType>(sort_pair.first);
});
}

template <typename DataType>
void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
return sort_impl<DataType, DataType>(input, output, axis, is_ascend, false);
return sort_impl<DataType, DataType>(
input, output, axis, is_ascend,
[](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
out_ptr[index] = sort_pair.second;
});
}

// Argsort implemented C library sort.
@@ -254,6 +277,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float16") {
if (out_dtype == "int32") {
argsort<float16, int32_t>(input, output, axis, is_ascend);
} else if (out_dtype == "int64") {
argsort<float16, int64_t>(input, output, axis, is_ascend);
} else if (out_dtype == "float32") {
argsort<float16, float>(input, output, axis, is_ascend);
} else if (out_dtype == "float64") {
argsort<float16, double>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
@@ -295,6 +330,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal
sort<int32_t>(input, output, axis, is_ascend);
} else if (data_dtype == "int64") {
sort<int64_t>(input, output, axis, is_ascend);
} else if (data_dtype == "float16") {
sort<float16>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
@@ -432,6 +469,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float16") {
if (out_dtype == "int32") {
topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "int64") {
topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float32") {
topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend);
} else if (out_dtype == "float64") {
topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
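For reference, a short standalone illustration (numpy only, outside TVM) of what the float16 wrapper and the CompareAscend/CompareDescend specializations above do: widen the binary16 bit pattern to float32 and compare the widened values. This is a sketch with illustrative names, not part of the diff.

import numpy as np

# 1.0, 0.5 and 2.0 encoded as IEEE 754 binary16 bit patterns, mirroring the
# uint16_t "bits" field of the float16 struct added in sort.cc.
bits = np.array([0x3C00, 0x3800, 0x4000], dtype=np.uint16)

# Reinterpret the bits as float16 and widen to float32 -- the same effect as
# float16::to_float() calling __extendXfYf2__ in builtin_fp16.h.
as_fp16 = bits.view(np.float16)
as_fp32 = as_fp16.astype(np.float32)
print(as_fp32)  # [1.  0.5 2. ]

# A stable argsort over the widened values mirrors how the specialized
# comparators drive std::stable_sort on (index, float16) pairs.
order = np.argsort(as_fp32, kind="stable")
print(order)  # [1 0 2]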
36 changes: 20 additions & 16 deletions tests/python/relay/test_op_level6.py
@@ -25,15 +25,14 @@

@tvm.testing.uses_gpu
def test_sort():
def verify_sort(shape, axis, is_ascend, is_dyn=False):

def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
Contributor:

We could also do a small refactoring for this file:

  • Use @pytest.mark.parametrize("in_dtype", ["float32", "float16"]) in each unit test.
  • Let pytest collect tests in the main function:

if __name__ == "__main__":
    pytest.main([__file__])

Member (Author):

Wouldn't that make the testing time double? This file is fairly slow to test (more than 3 min according to https://ci.tlcpack.ai/job/tvm/job/main/1384/testReport/ctypes.tests.python.relay/test_op_level6/) and I don't think fp16 tests need to run as often as fp32.

Contributor:

Hmm, that's a fair concern.
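For illustration, a rough sketch of what the suggested parametrization could look like — hypothetical code, not part of this PR, assuming the relay executor API used elsewhere in this test file:

import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import relay


# Hypothetical: run the same sort test once per input dtype instead of adding
# dedicated float16 calls at the end of each test.
@pytest.mark.parametrize("in_dtype", ["float32", "float16"])
def test_sort_parametrized(in_dtype):
    shape = (3, 5, 6)
    x = relay.var("x", relay.TensorType(shape, in_dtype))
    func = relay.Function([x], relay.sort(x, axis=-1, is_ascend=True))
    x_data = np.random.uniform(size=shape).astype(in_dtype)
    ref_res = np.sort(x_data, axis=-1)
    for target, dev in tvm.testing.enabled_targets():
        op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3)


if __name__ == "__main__":
    pytest.main([__file__])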

if is_dyn:
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
else:
x = relay.var("x", relay.TensorType(shape, "float32"))
x = relay.var("x", relay.TensorType(shape, in_dtype))
z = relay.sort(x, axis=axis, is_ascend=is_ascend)
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
x_data = np.random.uniform(size=shape).astype(in_dtype)
if is_ascend:
ref_res = np.sort(x_data, axis=axis)
else:
@@ -56,18 +55,19 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False):
verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn)
verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn)
verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn)
verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn, in_dtype="float16")


@tvm.testing.uses_gpu
def test_argsort():
def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float32"):
if is_dyn:
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
else:
x = relay.var("x", relay.TensorType(shape, "float32"))
x = relay.var("x", relay.TensorType(shape, in_dtype))
z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
x_data = np.random.uniform(size=shape).astype(in_dtype)
if is_ascend:
ref_res = np.argsort(x_data, axis=axis, kind="stable")
else:
@@ -93,31 +93,34 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
verify_argsort(
(1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn, in_dtype="float16"
)


@tvm.testing.uses_gpu
def test_topk():
def verify_topk(k, axis, ret_type, is_ascend, dtype):
def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):
shape = (20, 100)
x = relay.var("x", relay.TensorType(shape, "float32"))
x = relay.var("x", relay.TensorType(shape, in_dtype))
out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
if isinstance(out, relay.expr.TupleWrapper):
out = out.astuple()
func = relay.Function([x], out)
np_data = np.random.uniform(size=shape).astype("float32")
np_data = np.random.uniform(size=shape).astype(in_dtype)
if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
np_indices = np.argsort(np_data, axis=axis, kind="stable")
else:
np_indices = np.argsort(-np_data, axis=axis)
np_indices = np.argsort(-np_data, axis=axis, kind="stable")
kk = k if k >= 1 else shape[axis]
if axis == 0:
np_indices = np_indices[:kk, :]
np_values = np.zeros(np_indices.shape).astype("float32")
np_values = np.zeros(np_indices.shape).astype(in_dtype)
for i in range(shape[1]):
np_values[:, i] = np_data[np_indices[:, i], i]
else:
np_indices = np_indices[:, :kk]
np_values = np.zeros(np_indices.shape).astype("float32")
np_values = np.zeros(np_indices.shape).astype(in_dtype)
for i in range(shape[0]):
np_values[i, :] = np_data[i, np_indices[i, :]]
np_indices = np_indices.astype(dtype)
@@ -140,6 +143,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
for ret_type in ["both", "values", "indices"]:
verify_topk(k, axis, ret_type, True, "int64")
verify_topk(k, axis, ret_type, False, "float32")
verify_topk(k, axis, ret_type, False, "int64", "float16")


if __name__ == "__main__":
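Finally, a hypothetical smoke test of the new fp16 path. It assumes a TVM build that includes this patch with the contrib sort kernels compiled in; calling the registered packed function directly like this is an assumption about the runtime FFI, not something the PR itself demonstrates.

import numpy as np
import tvm

# Call the registered CPU kernel "tvm.contrib.sort.sort" directly on float16 data.
data_np = np.random.uniform(size=(1, 32)).astype("float16")
data = tvm.nd.array(data_np)
out = tvm.nd.array(np.zeros_like(data_np))

sort_fn = tvm.get_global_func("tvm.contrib.sort.sort")
sort_fn(data, out, 1, True)  # arguments: input, output, axis, is_ascend

np.testing.assert_array_equal(out.numpy(), np.sort(data_np, axis=1))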