Skip to content

Commit

Permalink
Add F4E2M1FN type: literal support
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 18, 2024
1 parent 87d0056 commit 70ca820
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 61 deletions.
11 changes: 10 additions & 1 deletion xla/hlo/translate/hlo_to_mhlo/tests/import.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ add {

// CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>
%constant.13 = f8e3m4[4] constant({1, 2, 3, 4})

// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN>
%constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4})
}

// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
Expand Down Expand Up @@ -542,7 +545,13 @@ add {
%convert.15 = f8e3m4[4] convert(f32[4] %convert.14)

// CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32>
ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15)
%convert.16 = f32[4] convert(f8e3m4[4] %convert.15)

// CHECK-NEXT: %14 = mhlo.convert %13 : (tensor<4xf32>) -> tensor<4xf4E2M1FN>
%convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16)

// CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32>
ROOT %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17)
}

// CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8>
Expand Down
6 changes: 6 additions & 0 deletions xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ xla::Array<T> ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) {
xla::Array<T> array(shape.dimensions());
if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) {
array.SetValues(dense_attr.getValues<T>());
} else if constexpr (xla::primitive_util::IsMXType(type)) {
// Bitcast MX floating point types from APFloat.
auto values = dense_attr.getValues<llvm::APFloat>();
for (int i = 0; i < values.size(); i++) {
array.data()[i] = T::FromRep(values[i].bitcastToAPInt().getZExtValue());
}
} else {
// The only way to get subbyte integers from getValues() is to get them as
// APInts.
Expand Down
11 changes: 9 additions & 2 deletions xla/hlo/translate/mhlo_to_hlo/tests/export.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ func.func @main() {
// CHECK: f8e3m4[4] constant({1, 2, 3, 4})
%cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4>

// CHECK: f4e2m1fn[4] constant({1, 2, 3, 4})
%cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN>

func.return
}

Expand Down Expand Up @@ -739,7 +742,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32>
%10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4>
%11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32>
func.return %11 : tensor<2xf32>
%12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN>
%13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32>
func.return %13 : tensor<2xf32>
}

// CHECK: ENTRY
Expand All @@ -755,7 +760,9 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]])
// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]])
// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]])
// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])
// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]])
// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]])
// CHECK: ROOT %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]])

// -----

Expand Down
28 changes: 20 additions & 8 deletions xla/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ bool LiteralProtoHasValues(const LiteralProto& proto) {
!proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() ||
!proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() ||
!proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() ||
!proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() ||
!proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() ||
!proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() ||
!proto.f8e3m4s().empty() || !proto.f16s().empty() ||
!proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() ||
proto.c64s_size() || proto.c128s_size() || proto.preds_size() ||
proto.tuple_literals_size();
!proto.f4e2m1fns().empty() || !proto.f8e3m4s().empty() ||
!proto.f8e4m3b11fnuzs().empty() || !proto.f8e4m3fns().empty() ||
!proto.f8e4m3fnuzs().empty() || !proto.f8e4m3s().empty() ||
!proto.f8e5m2fnuzs().empty() || !proto.f8e5m2s().empty() ||
!proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() ||
proto.f64s_size() || proto.c64s_size() || proto.c128s_size() ||
proto.preds_size() || proto.tuple_literals_size();
}

// Lazy getter for the interned scalar shape in static storage. We reuse this
Expand Down Expand Up @@ -1874,7 +1874,6 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
<< __func__ << " is only supported for dense arrays: " << subshape();
CHECK_EQ(size_bytes_dense(), other.size_bytes_dense());
if (primitive_util::IsSubByteNonPredType(subshape().element_type())) {
CHECK(!primitive_util::IsFloatingPointType(subshape().element_type()));
auto one_array = buffer();
auto two_array = other.buffer();
const int bits_per_element =
Expand Down Expand Up @@ -2259,6 +2258,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case S64:
CopyToRepeatedField(proto->mutable_s64s(), data<int64_t>());
break;
case F4E2M1FN:
*proto->mutable_f4e2m1fns() = std::string(
reinterpret_cast<const char*>(data<tsl::float4_e2m1fn>().data()),
size_bytes_dense());
break;
case F8E5M2:
*proto->mutable_f8e5m2s() = std::string(
reinterpret_cast<const char*>(data<tsl::float8_e5m2>().data()),
Expand Down Expand Up @@ -2445,6 +2449,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
case U64:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64_t>(), proto.u64s()));
break;
case F4E2M1FN: {
const std::string& s(proto.f4e2m1fns());
TF_RET_CHECK(data<tsl::float4_e2m1fn>().size() *
sizeof(tsl::float4_e2m1fn) ==
s.size());
memcpy(untyped_data(), s.data(), s.size());
break;
}
case F8E5M2: {
const std::string& s(proto.f8e5m2s());
TF_RET_CHECK(data<tsl::float8_e5m2>().size() * sizeof(tsl::float8_e5m2) ==
Expand Down
25 changes: 15 additions & 10 deletions xla/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,6 @@ class LiteralBase {
primitive_util::NativeToPrimitiveType<NativeT>();
constexpr int bits_per_element = primitive_util::BitWidth(primitive_type);
if constexpr (bits_per_element < 8) {
static_assert(!primitive_util::IsFloatingPointType(primitive_type));
static_assert(!primitive_util::IsComplexType(primitive_type));
static_assert(8 % bits_per_element == 0);
constexpr int elements_per_byte = 8 / bits_per_element;
Expand All @@ -598,9 +597,9 @@ class LiteralBase {
for (int64_t i = 0; i < bytes; ++i) {
uint8_t byte = 0;
for (int b = 0; b < elements_per_byte; ++b) {
uint8_t src =
static_cast<uint8_t>(elements[i * elements_per_byte + b]) &
LsbMask<uint8_t>(bits_per_element);
uint8_t src = Eigen::numext::bit_cast<uint8_t>(
elements[i * elements_per_byte + b]) &
LsbMask<uint8_t>(bits_per_element);
byte |= src << (b * bits_per_element);
}
WriteElement(byte);
Expand All @@ -609,9 +608,9 @@ class LiteralBase {
if (rest != 0) {
uint8_t byte = 0;
for (int64_t b = 0; b < rest; ++b) {
uint8_t src =
static_cast<uint8_t>(elements[bytes * elements_per_byte + b]) &
LsbMask<uint8_t>(bits_per_element);
uint8_t src = Eigen::numext::bit_cast<uint8_t>(
elements[bytes * elements_per_byte + b]) &
LsbMask<uint8_t>(bits_per_element);
byte |= src << (b * bits_per_element);
}
WriteElement(byte);
Expand Down Expand Up @@ -701,10 +700,16 @@ class LiteralBase {
primitive_util::NativeToPrimitiveType<NativeT>();
constexpr int bits_per_element = primitive_util::BitWidth(primitive_type);
if constexpr (bits_per_element < 8) {
static_assert(!primitive_util::IsFloatingPointType(primitive_type));
static_assert(!primitive_util::IsComplexType(primitive_type));
static_assert(8 % bits_per_element == 0);

constexpr int elements_per_byte = 8 / bits_per_element;
// Converts one extracted sub-byte bit pattern into NativeT. Floating-point
// element types are reconstructed from their raw bits; every other type goes
// through an ordinary value conversion.
constexpr auto cast = [](uint8_t x) -> NativeT {
  if constexpr (primitive_util::IsFloatingPointType(primitive_type)) {
    return Eigen::numext::bit_cast<NativeT>(x);
  } else {
    // Kept in an explicit `else` so this conversion is a discarded statement
    // for floating-point types: it is neither instantiated for them nor left
    // as unreachable code after the `return` above.
    return static_cast<NativeT>(x);
  }
};

int64_t bytes = elements.size() / elements_per_byte;
for (int64_t i = 0; i < bytes; ++i) {
Expand All @@ -714,7 +719,7 @@ class LiteralBase {
}
for (int b = 0; b < elements_per_byte; ++b) {
elements[i * elements_per_byte + b] =
static_cast<NativeT>(byte & LsbMask<uint8_t>(bits_per_element));
cast(byte & LsbMask<uint8_t>(bits_per_element));
byte >>= bits_per_element;
}
}
Expand All @@ -726,7 +731,7 @@ class LiteralBase {
}
for (int64_t b = 0; b < rest; ++b) {
elements[bytes * elements_per_byte + b] =
static_cast<NativeT>(byte & LsbMask<uint8_t>(bits_per_element));
cast(byte & LsbMask<uint8_t>(bits_per_element));
byte >>= bits_per_element;
}
}
Expand Down
47 changes: 27 additions & 20 deletions xla/literal_comparison_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,30 @@ namespace {
// Empty typed-test fixture. The element type under test is supplied through
// the TestedTypes list registered with TYPED_TEST_SUITE below.
template <typename T>
class LiteralComparisonTest : public ::testing::Test {};

// NOTE(review): `TestedTypes` is defined twice in this span. This looks like
// the before/after lines of a diff rendered together -- confirm that only the
// second definition is live. Relative to the first list, the second adds
// tsl::float4_e2m1fn, tsl::float8_e4m3fnuz and tsl::float8_e5m2fnuz.
using TestedTypes =
::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fn,
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2>;
using TestedTypes = ::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4,
tsl::float8_e4m3, tsl::float8_e4m3b11fnuz,
tsl::float8_e4m3fn, tsl::float8_e4m3fnuz,
tsl::float8_e5m2, tsl::float8_e5m2fnuz>;
// Instantiates every TYPED_TEST of LiteralComparisonTest once per listed type.
TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes);

// Checks that literal_comparison::Near reports OK when the expected and actual
// rank-0 literals are built from the same value, using ErrorSpec(0.0, 0.0).
TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) {
// NOTE(review): `actual` and `expected` are each declared twice below (with
// 8.0 and with 1.0). This looks like the removed and added lines of a diff
// shown together -- confirm that only the 1.0 pair is current.
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
// Detailed messages are disabled and no miscompare callback is installed.
TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0),
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) {
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
float expV = 9.0; // F8E4M3*
if (type == F8E5M2)
expV = 10.0;
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
float expV = 1.125; // F8E4M3*
if (type == F8E5M2 || type == F8E5M2FNUZ)
expV = 1.25;
else if (type == F8E3M4)
expV = 8.5;
expV = 1.0625;
else if (type == F4E2M1FN)
expV = 1.5;
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam{expV});
auto error_spec = ErrorSpec(0.0, 0.0);
EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec,
Expand All @@ -64,12 +67,14 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) {

TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) {
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
float expV = 12.0; // F8E4M3*
if (type == F8E5M2)
expV = 14.0;
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
float expV = 1.5; // F8E4M3*
if (type == F8E5M2 || type == F8E5M2FNUZ)
expV = 1.75;
else if (type == F8E3M4)
expV = 10.0;
expV = 1.25;
else if (type == F4E2M1FN)
expV = 4.0;
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam{expV});
auto error_spec = ErrorSpec(0.0, 0.0);
error_spec.low_precision_fp_error_spec.type = type;
Expand All @@ -86,12 +91,14 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) {

TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) {
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
auto actual = LiteralUtil::CreateR0<float>(8.0);
float expV = 12.1; // F8E4M3*
if (type == F8E5M2)
expV = 13.0;
auto actual = LiteralUtil::CreateR0<float>(1.0);
float expV = 1.51; // F8E4M3*
if (type == F8E5M2 || type == F8E5M2FNUZ)
expV = 1.76;
else if (type == F8E3M4)
expV = 10.125;
expV = 1.26;
else if (type == F4E2M1FN)
expV = 4.1;
auto expected = LiteralUtil::CreateR0<float>(expV);
auto error_spec = ErrorSpec(0.0, 0.0);
error_spec.low_precision_fp_error_spec.type = type;
Expand Down
Loading

0 comments on commit 70ca820

Please sign in to comment.