Skip to content

Commit

Permalink
Correct NF4 <-> floating point deduction
Browse files Browse the repository at this point in the history
  • Loading branch information
praasz committed Mar 18, 2024
1 parent 0ee0ef6 commit 23a0e32
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/core/reference/include/openvino/reference/convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
#include "openvino/core/type/nf4.hpp"

namespace ov {

template <class ElementIter>
constexpr bool is_nf4_iterator() {
using it = typename std::decay<ElementIter>::type;
using T = fundamental_type_for<element::nf4>;
return std::is_same<it, element::Iterator<element::nf4, const T>>::value ||
std::is_same<it, element::Iterator<element::nf4, T>>::value;
}

namespace reference {
namespace detail {

Expand All @@ -33,15 +42,10 @@ void convert(InputIt arg, OutputIt out, const size_t count) {
using OUT_T = typename std::iterator_traits<OutputIt>::value_type;

// Deduce types for NF4 <-> floating point conversion to use quantization.
using From =
typename std::conditional<std::is_same<InputIt, element::Iterator<element::nf4, const int8_t>>::value &&
!std::is_integral<OUT_T>::value,
const float,
IN_T>::type;
using To = typename std::conditional<std::is_same<OutputIt, element::Iterator<element::nf4, int8_t>>::value &&
!std::is_integral<IN_T>::value,
float,
OUT_T>::type;
using From = typename std::
conditional<is_nf4_iterator<InputIt>() && !std::is_integral<OUT_T>::value, const float, IN_T>::type;
using To =
typename std::conditional<is_nf4_iterator<OutputIt>() && !std::is_integral<IN_T>::value, float, OUT_T>::type;

std::transform(arg, arg + count, out, detail::convert<From, To>);
}
Expand Down Expand Up @@ -69,7 +73,7 @@ void convert<float16, int8_t>(const float16* arg, int8_t* out, size_t count);
// Count how many f32 values are out of the normal finite number range when converted to f16.
// NOTE(review): presumably `arg` points to `count` contiguous f32 values - confirm against the definition.
size_t count_out_of_f16_range(const float* arg, size_t count);

// Convert values from f32 to f16 with claming to f16 min/max when value is out of normal finite numbers range
// Convert values from f32 to f16 with clamping to f16 min/max when value is out of normal finite numbers range
void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t count);
} // namespace reference
} // namespace ov

0 comments on commit 23a0e32

Please sign in to comment.