Skip to content

Commit

Permalink
ConvertFloat32ToFloat16: Use DirectXMath conversion functions (micros…
Browse files Browse the repository at this point in the history
…oft#4855)

Custom half <-> float conversion functions had problems in multiple scenarios.  This PR changes them into a wrapper, using the DirectXMath conversion functions instead.
  • Loading branch information
tex3d authored Dec 15, 2022
1 parent 5c4d3b6 commit 6acd11b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 85 deletions.
88 changes: 3 additions & 85 deletions include/dxc/Test/HlslTestUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,91 +406,9 @@ inline bool isnanFloat16(uint16_t val) {
(val & FLOAT16_BIT_MANTISSA) != 0;
}

// Converts a 32-bit float to the bit pattern of a 16-bit (half precision)
// float.
//
//  - The sign bit of the input is always carried over to the result.
//  - NaN inputs produce a half NaN with every mantissa bit set.
//  - Infinities, and finite values whose magnitude is 2^16 or larger (too big
//    for a half exponent), produce a signed half infinity.
//  - Every other value is truncated (no rounding) to the nearest half value
//    toward zero; magnitudes below 2^-14 become half denormals or zero.
//
// FLOAT16_BIT_EXP and FLOAT16_BIT_MANTISSA are declared elsewhere in this
// header (presumably the half exponent and mantissa masks -- confirm there).
inline uint16_t ConvertFloat32ToFloat16(float val) {
  union Bits {
    uint32_t u_bits;
    float f_bits;
  };

  // Sign bit of an f32.
  static const uint32_t SignMask32 = 0x80000000;

  // pow(2,-14): minimum f32 value representable in f16 format without
  // denormalizing.
  static const uint32_t Min16in32 = 0x38800000;

  // pow(2,16): smallest f32 magnitude whose exponent no longer fits in f16.
  static const uint32_t Overflow16in32 = 0x47800000;

  // Bit pattern of f32 +infinity; anything above it (sign stripped) is a NaN.
  static const uint32_t Inf32 = 0x7F800000;

  // pow(2,24): scales an f16-denormal-range value so that its integer part is
  // the f16 mantissa.
  const float DenormalRatio = 16777216.0f;

  // Exponent re-bias between the two formats: (127 - 15) << 23.
  static const uint32_t NormalDelta = 0x38000000;

  Bits bits;
  bits.f_bits = val;
  const uint16_t sign =
      static_cast<uint16_t>((bits.u_bits & SignMask32) >> 16);
  Bits Abs;
  Abs.u_bits = bits.u_bits & ~SignMask32;

  if (Abs.u_bits > Inf32) {
    // NaN: exponent all ones, mantissa all ones.
    return static_cast<uint16_t>(FLOAT16_BIT_MANTISSA | FLOAT16_BIT_EXP | sign);
  }

  if (Abs.u_bits >= Overflow16in32) {
    // Infinity, or a finite value too large for f16. Without this check the
    // normal path below would overflow the 5-bit exponent and return NaN or
    // wrong-signed garbage, so saturate to infinity instead.
    return static_cast<uint16_t>(FLOAT16_BIT_EXP | sign);
  }

  // For non-negative, non-NaN floats the ordering of the bit patterns matches
  // the ordering of the values, so compare as integers rather than
  // reinterpreting the constant through a pointer cast.
  if (Abs.u_bits < Min16in32) {
    // Denormal (or zero) result: the product lies in [0, 1024) and its
    // truncated integer part is the f16 mantissa.
    return static_cast<uint16_t>(
        static_cast<uint16_t>(Abs.f_bits * DenormalRatio) | sign);
  }

  // Normal result: re-bias the exponent and drop the low 13 mantissa bits.
  return static_cast<uint16_t>(
      static_cast<uint16_t>((Abs.u_bits - NormalDelta) >> 13) | sign);
}

// Converts the bit pattern of a 16-bit (half precision) float to a 32-bit
// float.
//
// The sign-stripped input is classified as follows:
//   nan    -> exponent all set and mantissa is nonzero
//   +/-inf -> exponent all set and mantissa is zero
//   denorm -> exponent zero and mantissa nonzero
// The sign bit is moved to bit 31 of the result in every case.
//
// FLOAT16_BIT_SIGN, FLOAT16_BIGGEST_DENORM and FLOAT16_BIGGEST_NORMAL are
// declared elsewhere in this header.
inline float ConvertFloat16ToFloat32(uint16_t x) {
  union Bits {
    float f_bits;
    uint32_t u_bits;
  };

  // pow(2,-24): value of one unit in the last place of an f16 denormal.
  // Written as a float expression so no pointer cast is needed to build it.
  const float DenormRatio = 1.0f / 16777216.0f;
  // Exponent re-bias for normal values: (127 - 15) << 23.
  const uint32_t NormalBias = 0x38000000;
  // Added to the shifted all-ones f16 exponent to give the all-ones f32
  // exponent (0x0F800000 + 0x70000000 == 0x7F800000).
  const uint32_t InfNaNBias = 0x70000000;

  // Widen before shifting so the sign bit is never shifted into a signed int.
  const uint32_t Sign = static_cast<uint32_t>(x & FLOAT16_BIT_SIGN) << 16;
  const uint32_t Abs = (x & 0x7fff);

  Bits bits;
  if (Abs > FLOAT16_BIGGEST_NORMAL) {
    // Signless result for infs & nans: keep the mantissa payload.
    bits.u_bits = (Abs << 13) + InfNaNBias;
  } else if (Abs > FLOAT16_BIGGEST_DENORM) {
    // Signless result for normals.
    bits.u_bits = (Abs << 13) + NormalBias;
  } else {
    // Signless result for denormals and zero: mantissa * pow(2,-24).
    bits.f_bits = Abs * DenormRatio;
  }
  bits.u_bits |= Sign;
  return bits.f_bits;
}
// Declarations of the half <-> float conversion helpers; no definitions are
// provided at this point.
// NOTE(review): these two declarations carry no exception specification while
// the pair below is declared throw() -- confirm only one pair is meant to
// remain.
uint16_t ConvertFloat32ToFloat16(float val);
float ConvertFloat16ToFloat32(uint16_t val);
// These are defined in ShaderOpTest.cpp using DirectXPackedVector functions
// (XMConvertFloatToHalf / XMConvertHalfToFloat). Both take and return the
// 16-bit value as a raw uint16_t bit pattern.
uint16_t ConvertFloat32ToFloat16(float val) throw();
float ConvertFloat16ToFloat32(uint16_t val) throw();

inline bool CompareFloatULP(const float &fsrc, const float &fref, int ULPTolerance,
hlsl::DXIL::Float32DenormMode mode = hlsl::DXIL::Float32DenormMode::Any) {
Expand Down
8 changes: 8 additions & 0 deletions tools/clang/unittests/HLSL/ShaderOpTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <stdlib.h>
#include <DirectXMath.h>
#include <DirectXPackedVector.h>
#include <intsafe.h>
#include <strsafe.h>
#include <xmllite.h>
Expand All @@ -40,6 +41,13 @@
///////////////////////////////////////////////////////////////////////////////
// Useful helper functions.

uint16_t ConvertFloat32ToFloat16(float Value) throw() {
return DirectX::PackedVector::XMConvertFloatToHalf(Value);
}
float ConvertFloat16ToFloat32(uint16_t Value) throw() {
return DirectX::PackedVector::XMConvertHalfToFloat(Value);
}

static st::OutputStringFn g_OutputStrFn;
static void * g_OutputStrFnCtx;

Expand Down

0 comments on commit 6acd11b

Please sign in to comment.