From 23a0e32ce3a1b153fb10f91f0d85dfeb54fd0283 Mon Sep 17 00:00:00 2001
From: "Raasz, Pawel"
Date: Sun, 17 Mar 2024 07:27:06 +0000
Subject: [PATCH] Correct NF4 <-> floating point deduction

---
 .../include/openvino/reference/convert.hpp | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/core/reference/include/openvino/reference/convert.hpp b/src/core/reference/include/openvino/reference/convert.hpp
index 479e89c6f2e5b3..19872a427f17f5 100644
--- a/src/core/reference/include/openvino/reference/convert.hpp
+++ b/src/core/reference/include/openvino/reference/convert.hpp
@@ -13,6 +13,15 @@
 #include "openvino/core/type/nf4.hpp"
 
 namespace ov {
+
+template
+constexpr bool is_nf4_iterator() {
+    using it = typename std::decay::type;
+    using T = fundamental_type_for;
+    return std::is_same>::value ||
+        std::is_same>::value;
+}
+
 namespace reference {
 namespace detail {
 
@@ -33,15 +42,10 @@ void convert(InputIt arg, OutputIt out, const size_t count) {
     using OUT_T = typename std::iterator_traits::value_type;
 
     // Deduce types for NF4 <-> floating point conversion to use quantization.
-    using From =
-        typename std::conditional>::value &&
-        !std::is_integral::value,
-        const float,
-        IN_T>::type;
-    using To = typename std::conditional>::value &&
-        !std::is_integral::value,
-        float,
-        OUT_T>::type;
+    using From = typename std::
+        conditional() && !std::is_integral::value, const float, IN_T>::type;
+    using To =
+        typename std::conditional() && !std::is_integral::value, float, OUT_T>::type;
 
     std::transform(arg, arg + count, out, detail::convert);
 }
@@ -69,7 +73,7 @@ void convert(const float16* arg, int8_t* out, size_t count);
 // Count how many f32 values is out of normal finite numbers range when converted to f16
 size_t count_out_of_f16_range(const float* arg, size_t count);
 
-// Convert values from f32 to f16 with claming to f16 min/max when value is out of normal finite numbers range
+// Convert values from f32 to f16 with clamping to f16 min/max when value is out of normal finite numbers range
 void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t count);
 } // namespace reference
 } // namespace ov