From 8a04c2cc98b8d25a7b7ab648c8324bab4e7cceab Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 6 Aug 2021 17:59:08 +0900 Subject: [PATCH 1/5] fp16 topk test working --- src/runtime/contrib/sort/sort.cc | 27 +++++++++++++++++++++++++++ tests/python/relay/test_op_level6.py | 17 +++++++++-------- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 66f36ffa50d6..a8b6392c0bdf 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -21,6 +21,7 @@ * \file Use standard C library call. */ +#include #include #include @@ -42,6 +43,24 @@ bool CompareDescend(const std::pair& lhs, const std::pair rhs.second; } +struct float16 { + uint16_t bits; + float to_float() const { + return __extendXfYf2__(bits); + } +}; + +template <> +bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { + return lhs.second.to_float() < rhs.second.to_float(); +} + +template <> +bool CompareDescend(const std::pair& lhs, + const std::pair& rhs) { + return lhs.second.to_float() > rhs.second.to_float(); +} + // Argsort implemented C library sort for nms. // Return indices of sorted tensor. // By default, the last axis will be used to sort. @@ -432,6 +451,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 1838233e3a3a..50d02012e7d9 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -85,7 +85,7 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): tvm.testing.assert_allclose(op_res.numpy(), ref_res.astype(dtype), rtol=1e-5) for is_dyn in [False, True]: - for dtype in ["int32", "int64", "float32", "float64"]: + for dtype in ["int32", "int64", "float32", "float64", "float16"]: verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn) dtype = "int32" @@ -97,27 +97,27 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): @tvm.testing.uses_gpu def test_topk(): - def verify_topk(k, axis, ret_type, is_ascend, dtype): + def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): shape = (20, 100) - x = relay.var("x", relay.TensorType(shape, "float32")) + x = relay.var("x", relay.TensorType(shape, in_dtype)) out = relay.topk(x, k, axis, ret_type, is_ascend, dtype) if isinstance(out, relay.expr.TupleWrapper): out = out.astuple() func = relay.Function([x], out) - np_data = np.random.uniform(size=shape).astype("float32") + np_data = np.random.uniform(size=shape).astype(in_dtype) if is_ascend: - np_indices = np.argsort(np_data, axis=axis) + np_indices = np.argsort(np_data, axis=axis, kind="stable") else: - np_indices = np.argsort(-np_data, axis=axis) + np_indices = np.argsort(-np_data, axis=axis, kind="stable") kk = k if k >= 1 else shape[axis] if axis == 0: np_indices = np_indices[:kk, :] - np_values = np.zeros(np_indices.shape).astype("float32") + np_values = np.zeros(np_indices.shape).astype(in_dtype) for i in range(shape[1]): np_values[:, i] = np_data[np_indices[:, i], i] else: np_indices = np_indices[:, :kk] - np_values = np.zeros(np_indices.shape).astype("float32") + np_values = np.zeros(np_indices.shape).astype(in_dtype) for i in range(shape[0]): np_values[i, :] = np_data[i, np_indices[i, :]] np_indices = np_indices.astype(dtype) @@ -140,6 +140,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): for ret_type in ["both", "values", "indices"]: verify_topk(k, axis, ret_type, True, "int64") verify_topk(k, axis, ret_type, False, "float32") + verify_topk(k, axis, ret_type, False, "int64", "float16") if __name__ == "__main__": From a134af39afa1d0ca49727a4f28e7f15d88494640 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 6 Aug 2021 18:15:45 +0900 Subject: [PATCH 2/5] sort working --- src/runtime/contrib/sort/sort.cc | 28 +++++++++++++++++----------- tests/python/relay/test_op_level6.py | 10 +++++----- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index a8b6392c0bdf..93f7aaa91b9a 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -144,7 +144,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV }); template -void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) { +void sort_impl( + DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, + std::function&)> epilogue) { auto data_ptr = static_cast(input->data); auto out_ptr = static_cast(output->data); std::vector> sorter; @@ -172,14 +174,8 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, } else { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); } - if (is_argsort) { - for (int64_t k = 0; k < input->shape[axis]; ++k) { - out_ptr[base_idx + k * axis_mul_after] = static_cast(sorter[k].first); - } - } else { - for (int64_t k = 0; k < input->shape[axis]; ++k) { - out_ptr[base_idx + k * axis_mul_after] = static_cast(sorter[k].second); - } + for (int64_t k = 0; k < input->shape[axis]; ++k) { + epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]); } } } @@ -187,12 +183,20 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, template void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - return sort_impl(input, output, axis, is_ascend, true); + return sort_impl( + input, output, axis, is_ascend, + [](OutType* out_ptr, size_t index, const std::pair& sort_pair) { + out_ptr[index] = static_cast(sort_pair.first); + }); } template void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - return sort_impl(input, output, axis, is_ascend, false); + return sort_impl( + input, output, axis, is_ascend, + [](DataType* out_ptr, size_t index, const std::pair& sort_pair) { + out_ptr[index] = sort_pair.second; + }); } // Argsort implemented C library sort. @@ -314,6 +318,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal sort(input, output, axis, is_ascend); } else if (data_dtype == "int64") { sort(input, output, axis, is_ascend); + } else if (data_dtype == "float16") { + sort(input, output, axis, is_ascend); } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 50d02012e7d9..4c433d2d89db 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -25,15 +25,14 @@ @tvm.testing.uses_gpu def test_sort(): - def verify_sort(shape, axis, is_ascend, is_dyn=False): - + def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"): if is_dyn: - x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32")) + x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype)) else: - x = relay.var("x", relay.TensorType(shape, "float32")) + x = relay.var("x", relay.TensorType(shape, in_dtype)) z = relay.sort(x, axis=axis, is_ascend=is_ascend) func = relay.Function([x], z) - x_data = np.random.uniform(size=shape).astype("float32") + x_data = np.random.uniform(size=shape).astype(in_dtype) if is_ascend: ref_res = np.sort(x_data, axis=axis) else: @@ -56,6 +55,7 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False): verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn) verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn) verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn) + verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn, in_dtype="float16") @tvm.testing.uses_gpu From 52b1f04271305c0409e76a21e5c10abb834aecbd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 6 Aug 2021 18:40:09 +0900 Subject: [PATCH 3/5] support argmax --- src/runtime/contrib/sort/sort.cc | 16 ++++++++++++++++ tests/python/relay/test_op_level6.py | 13 ++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 93f7aaa91b9a..4aa8c92f5199 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -277,6 +277,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } @@ -462,6 +474,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal topk(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 4c433d2d89db..e53c6a1162d5 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -60,14 +60,14 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"): @tvm.testing.uses_gpu def test_argsort(): - def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): + def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float32"): if is_dyn: - x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32")) + x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype)) else: - x = relay.var("x", relay.TensorType(shape, "float32")) + x = relay.var("x", relay.TensorType(shape, in_dtype)) z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype) func = relay.Function([x], z) - x_data = np.random.uniform(size=shape).astype("float32") + x_data = np.random.uniform(size=shape).astype(in_dtype) if is_ascend: ref_res = np.argsort(x_data, axis=axis, kind="stable") else: @@ -85,7 +85,7 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): tvm.testing.assert_allclose(op_res.numpy(), ref_res.astype(dtype), rtol=1e-5) for is_dyn in [False, True]: - for dtype in ["int32", "int64", "float32", "float64", "float16"]: + for dtype in ["int32", "int64", "float32", "float64"]: verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn) dtype = "int32" @@ -93,6 +93,9 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_argsort( + (1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn, in_dtype="float16" + ) @tvm.testing.uses_gpu From e0616d1149d5ea90d5cb3e7cc3e3f7a646463329 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 7 Aug 2021 17:18:20 +0900 Subject: [PATCH 4/5] try fixing wasm build --- web/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/Makefile b/web/Makefile index 8c4dbc20dadc..34a1b8172484 100644 --- a/web/Makefile +++ b/web/Makefile @@ -18,7 +18,7 @@ TVM_ROOT=$(shell cd ..; pwd) INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ - -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -I$(TVM_ROOT)/3rdparty/compiler-rt .PHONY: clean all rmtypedep preparetest From 751f811927d1f81ad11bab482deaadadf88c48b4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 7 Aug 2021 17:31:15 +0900 Subject: [PATCH 5/5] use pytest main --- tests/python/relay/test_op_level6.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index e53c6a1162d5..f4a4dd4e6134 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -16,9 +16,9 @@ # under the License. """ Support level6 operator test cases. """ +import pytest import numpy as np import tvm -from tvm import te from tvm import relay import tvm.testing @@ -147,6 +147,4 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): if __name__ == "__main__": - test_sort() - test_argsort() - test_topk() + pytest.main([__file__])