From 1484d5a691a68d0edc5c87ee656d2adf4a9d4039 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Wed, 18 Dec 2024 12:44:07 -0800 Subject: [PATCH] PR #19096: Add F4E2M1FN and F8E8M0FNU types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/19096 This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented. This will enable using microscaling (MX) formats ([RFC](https://github.com/openxla/xla/discussions/18085)), such as MXFP4. ```c F4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 F8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - https://github.com/openxla/stablehlo/pull/2582 - https://github.com/jax-ml/ml_dtypes/pull/181 - https://github.com/llvm/llvm-project/pull/95392 - https://github.com/llvm/llvm-project/pull/108877 - https://github.com/jax-ml/ml_dtypes/pull/166 - https://github.com/llvm/llvm-project/pull/107127 - https://github.com/llvm/llvm-project/pull/111028 The PR is split into multiple commits just to make the review easier, it is possible that some tests could fail if only some (i.e. not all) of these commits are applied. Copybara import of the project: -- f493e4803eaa5ff3da3ceb130e9348c014b4a2e8 by Sergey Kozub : Add F4E2M1FN type: import mxfloat.h -- 87d005630b310a355d7c30b22828c35237373f17 by Sergey Kozub : Add F4E2M1FN type: primitive type -- 70ca82093faeec98f2dc5e8b82f617d99ca96849 by Sergey Kozub : Add F4E2M1FN type: literal support -- c479f0940da490e9668e2f48e14a7466f0c4a97f by Sergey Kozub : Add F4E2M1FN type: conversion codegen -- daaa3af3ce3af456f2ef44dbc291ebeb09e86d9b by Sergey Kozub : Add F4E2M1FN type: python interface -- 1f0e19ff14733eff790726936b68ef0cf607a766 by Sergey Kozub : Add F4E2M1FN type: FFI -- 999bf96092e57c7b3039811f2887281f347ff17a by Sergey Kozub : Add F4E2M1FN type: HLO evaluator -- d7d5af74c5f8a94522779a121c0a4a962156fb64 by Sergey Kozub : Add F4E2M1FN type: add tests -- 9e8c7bc02849f241d0f05941221d99f1d08d9e67 by Sergey Kozub : Add F8E8M0FNU type -- 1e344174b931cea4978770ab740dfed67186c2f4 by Sergey Kozub : Addressing PR#19096 review comments -- d4de0a369d9dc853f34f3cf3bf7dcc5a47502106 by Sergey Kozub : Addressing PR#19096 review comments (round 2) Merging this change closes #19096 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19096 from openxla:skozub/e2m1 d4de0a369d9dc853f34f3cf3bf7dcc5a47502106 PiperOrigin-RevId: 707638099 --- third_party/tsl/tsl/platform/BUILD | 1 + third_party/tsl/tsl/platform/ml_dtypes.h | 3 + xla/array2d_test.cc | 28 ++ .../codegen/transforms/expand_float_ops.cc | 191 ++++++----- .../gpu/codegen/transforms/lower_tensors.cc | 59 ++-- .../transforms/tests/expand_float_ops.mlir | 50 +++ .../transforms/tests/lower_tensors.mlir | 42 ++- xla/comparison_util.h | 9 +- xla/ffi/api/api.h | 4 + xla/ffi/api/c_api.h | 2 + xla/ffi/api/ffi.h | 6 + xla/ffi/api/ffi_test.cc | 6 + xla/ffi/call_frame.cc | 2 + xla/fp_util_test.cc | 70 +++++ xla/hlo/builder/lib/math.cc | 11 +- xla/hlo/builder/lib/math_test.cc | 32 +- xla/hlo/evaluator/BUILD | 1 + xla/hlo/evaluator/hlo_evaluator.cc | 2 +- .../evaluator/hlo_evaluator_typed_visitor.h | 2 + .../hlo_evaluator_typed_visitor_mxfloat.cc | 23 ++ .../expanders/comparison_expander.cc | 59 ++-- .../simplifiers/float_normalization.cc | 3 + .../simplifiers/float_normalization_test.cc | 4 +- xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc | 20 ++ .../translate/hlo_to_mhlo/tests/import.hlo | 20 +- .../translate/mhlo_to_hlo/literal_exporter.cc | 6 + .../translate/mhlo_to_hlo/tests/export.mlir | 18 +- xla/literal.cc | 36 ++- xla/literal.h | 29 +- xla/literal_comparison.cc | 7 +- xla/literal_comparison_test.cc | 52 +-- xla/literal_test.cc | 75 +++-- xla/mlir/utils/type_util.cc | 10 +- xla/mlir/utils/type_util_test.cc | 2 + xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 14 + xla/pjrt/c/CHANGELOG.md | 3 + xla/pjrt/c/pjrt_c_api.h | 6 +- xla/pjrt/c/pjrt_c_api_helpers.cc | 8 + xla/primitive_util.cc | 12 + xla/primitive_util.h | 80 ++++- xla/primitive_util_test.cc | 134 +++++++- xla/python/ifrt/dtype.cc | 8 + xla/python/ifrt/dtype.h | 6 +- xla/python/ifrt/dtype.proto | 6 + xla/python/ifrt/dtype_test.cc | 86 ++--- xla/python/pjrt_ifrt/pjrt_dtype.cc | 4 + xla/python/py_values.cc | 16 + xla/python/types.cc | 42 +++ xla/python/types.h | 2 + xla/python/xla.cc | 2 + xla/python/xla_client.py | 6 + xla/python/xla_client.pyi | 2 + xla/python/xla_client_test.py | 4 +- xla/python/xla_extension/__init__.pyi | 2 + xla/service/cpu/cpu_compiler.cc | 4 + xla/service/cpu/onednn_memory_util.h | 2 +- xla/service/elemental_ir_emitter.cc | 278 +++++++++++++++- xla/service/elemental_ir_emitter_test.cc | 15 +- xla/service/float8_fnuz_ir_emitter.cc | 17 +- .../gpu/fusions/triton/triton_support_test.cc | 34 +- xla/service/gpu/gpu_compiler.cc | 4 + .../gpu/tests/float_conversions_test.cc | 7 +- xla/service/hlo_verifier.cc | 3 +- xla/service/llvm_ir/llvm_util.cc | 3 + xla/stream_executor/data_type.h | 8 + xla/stream_executor/dnn.cc | 2 + xla/stream_executor/gpu/gpu_blas_lt.cc | 10 + xla/stream_executor/rocm/hip_blas_utils.cc | 6 +- xla/tests/BUILD | 2 + xla/tests/array_elementwise_ops_test.cc | 52 +-- xla/tests/constants_test.cc | 8 +- xla/tests/convert_test.cc | 297 +++++++++++++++++- xla/tools/driver.cc | 21 +- xla/tsl/framework/BUILD | 1 + xla/tsl/framework/type_traits.h | 5 +- xla/tsl/protobuf/dnn.proto | 2 + xla/tsl/python/lib/core/ml_dtypes.cc | 6 + xla/tsl/python/lib/core/ml_dtypes.h | 2 + xla/types.h | 16 + xla/util.cc | 10 + xla/util.h | 25 +- xla/util_test.cc | 28 +- xla/xla_data.proto | 27 +- 83 files changed, 1856 insertions(+), 367 deletions(-) create mode 100644 xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc diff --git a/third_party/tsl/tsl/platform/BUILD b/third_party/tsl/tsl/platform/BUILD index eb208e6ef65bac..9f3b135245a407 100644 --- a/third_party/tsl/tsl/platform/BUILD +++ b/third_party/tsl/tsl/platform/BUILD @@ -985,6 +985,7 @@ cc_library( deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ], ) diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index a6a1b56af88ad4..a03fa02447f3c6 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -18,8 +18,10 @@ limitations under the License. #include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "ml_dtypes/include/intn.h" // from @ml_dtypes +#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes namespace tsl { +using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn; using float8_e3m4 = ::ml_dtypes::float8_e3m4; using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; @@ -27,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; +using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; using int1 = ::ml_dtypes::int1; using uint1 = ::ml_dtypes::uint1; diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 921da30256fa3d..c62f6e882713e5 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF4E2M1FN) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up +} + +TEST(Array2dTest, LinspaceF8E8M0FNU) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 2.0); // 1.5 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 4.0); // 3.0 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc index 81cb99d66f82d9..ff2ce862277980 100644 --- a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc +++ b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc @@ -163,7 +163,13 @@ int GetSignificandBits(mlir::FloatType ty) { } int GetExponentBias(mlir::FloatType ty) { - return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()); + return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) - + ty.isFloat8E8M0FNU(); // No zero exponent for E8M0. +} + +bool IsFNUZ(mlir::FloatType ty) { + return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E5M2FNUZ(); } Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -175,7 +181,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { return b.create(ma::CmpFPredicate::OEQ, value, inf); } - assert(ty.getIntOrFloatBitWidth() == 8); + assert(ty.getIntOrFloatBitWidth() <= 8); // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. if (ty.isFloat8E5M2()) { Val bits{b.create(b.getI8Type(), value), &b}; @@ -196,6 +202,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { if (mlir::LLVM::isCompatibleOuterType(ty)) { return b.create(ma::CmpFPredicate::UNO, value, value); } + if (ty.isFloat4E2M1FN()) { + return b.create(false, b.getI1Type()); + } assert(ty.getIntOrFloatBitWidth() == 8); Val bits{b.create(b.getI8Type(), value), &b}; @@ -207,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { return (bits & 0b0111'1111) == 0b0111'1111; } else if (ty.isFloat8E3M4()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); + } else if (ty.isFloat8E8M0FNU()) { + return bits == 0xFF; } return bits == 0x80; } @@ -281,11 +292,18 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth()); mlir::IntegerType wide_int_ty; - if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) { + if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) { wide_int_ty = b.getI16Type(); } else { wide_int_ty = b.getIntegerType( std::max(from_int_ty.getWidth(), to_int_ty.getWidth())); + // Avoid overflow for bit shifts. + auto may_overflow = [&](mlir::Type a, mlir::Type b) { + return a.isFloat8E8M0FNU() && b.isF16(); + }; + if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) { + wide_int_ty = b.getI32Type(); + } } auto convert_int = [&](mlir::Type ty, Value v) -> Val { if (v.getType() == ty) { @@ -300,34 +318,49 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, int64_t exp_offset = to_bias - from_bias; int digit_shift = to_mantissa - from_mantissa; - Val from_bits{ - b.create( - b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value), - &b}; + int from_width = value.getType().getIntOrFloatBitWidth(); + Val from_bits{b.create(b.getIntegerType(from_width), value), + &b}; + if (from_width < 8) { + from_bits = convert_int(b.getIntegerType(8), from_bits); + } auto cst = [&](mlir::Type ty, int64_t n) -> Val { return {b.create(n, ty), &b}; }; // Shift bits to destination type, without sign bit. - Val from_sign_bit = - from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0; - - from_bits = - from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); - - Value result_is_inf = IsInf(value, b); - Value input_is_nan = IsNaN(value, b); + Val from_sign_bit; + if (!from_ty.isFloat8E8M0FNU()) { + from_sign_bit = from_bits.shrui(from_width - 1) != 0; + from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); + } auto cst_bits = [&](llvm::APFloat f) { return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())), f.bitcastToAPInt().getZExtValue()); }; - Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); - Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); - Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + Value to_nan; + Value to_inf; + Val to_zero; + + // MX float types have neither infinities nor NaNs. + if (to_ty.isFloat4E2M1FN()) { + to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + to_nan = to_zero | 0x8; + to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); + } else if (to_ty.isFloat8E8M0FNU()) { + to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); + to_inf = to_nan; + to_zero = Val{to_nan, &b}; + } else { + to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); + to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); + to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + } - auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) { + auto round_bits_to_nearest_even = [&](Val bits, Val roundoff, + bool use_implicit_bit = false) { assert(bits.value.getType() == roundoff.value.getType()); // Round to nearest even by adding a bias term. // Consider a bit pattern @@ -337,9 +370,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // - L is 1, R is 1, OR // - L is 0, R is 1, any T is one. // We do this by adding L to a bit pattern consisting of all T = 1. - Val rounded = (bits.shrui(roundoff) & 1) + - (bits.MakeConstant(1).shl(roundoff - 1) - 1); - Val bias{b.create(roundoff == 0, roundoff, rounded), &b}; + Val bias = !use_implicit_bit + ? (bits.shrui(roundoff) & 1) + + (bits.MakeConstant(1).shl(roundoff - 1) - 1) + : bits.MakeConstant(1).shl(roundoff - 1); return bits + bias; }; @@ -349,9 +383,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // Round the mantissa if it is shrinking. Val rounded_from_bits = convert_int(wide_int_ty, from_bits); if (digit_shift < 0) { - rounded_from_bits = round_bits_to_nearest_even( - from_bits, from_bits.MakeConstant(-digit_shift)) & - ~((1ll << (-digit_shift)) - 1); + rounded_from_bits = + round_bits_to_nearest_even( + rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0) & + ~((1ll << (-digit_shift)) - 1); } // Re-bias the exponent. @@ -394,10 +430,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Val bits = convert_int(wide_int_ty, from_bits); // Determine exponent in target type. - Value normalization_factor = - convert_int(i32_ty, - b.create(from_bits)) - - (from_int_ty.getWidth() - from_mantissa - 1); + Value clz = convert_int( + i32_ty, b.create(from_bits)); + Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz; + Value normalization_factor = cst(i32_ty, from_mantissa) - msb; Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor; // If the result is subnormal, adjust the subnormal bits to account for @@ -418,10 +454,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0); bits.value = b.create(biased_exp_sle_zero, subnormal_bits, normal_bits); - if (digit_shift > 0) { + if (digit_shift >= 0) { bits = bits.shl(digit_shift); } else { - bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift)); + bits = round_bits_to_nearest_even( + bits, bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0); bits = bits.shrui(-digit_shift); } bits = convert_int(to_int_ty, bits); @@ -430,11 +468,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, } else if (to_min_exp > from_min_exp) { // `To` supports fewer exponents near zero which means that some values in // `From` may become subnormal. - Val unbiased_exp = biased_from_exp - from_bias; - Val biased_to_exp = unbiased_exp + to_bias; + Val biased_to_exp = biased_from_exp + (to_bias - from_bias); // Subnormals and zero. // Round and shift mantissa down. - Val from_has_leading_one = biased_from_exp != 0; + Val from_has_leading_one = + !from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1); Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one); from_has_leading_one = convert_int(from_int_ty, from_has_leading_one); Val exponent_shift_i32 = @@ -469,31 +507,35 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, result); } - // Handle types with no unsigned zero. - auto is_nuz = [](mlir::FloatType ty) { - return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E5M2FNUZ(); - }; + Value result_is_inf = IsInf(value, b); + Value input_is_nan = IsNaN(value, b); - if (is_nuz(to_ty)) { + if (to_ty.isFloat8E8M0FNU()) { + // Converting a negative number to E8M0 results in NaN. + input_is_nan = from_sign_bit | input_is_nan; + } else if (IsFNUZ(to_ty)) { // Clear the sign bit if the result is zero (the output has no negative - // zero). - Val result_is_non_zero = Val{result, &b} != 0; + // zero). Handle the edge case when the input is zero and the result is not. + Val result_is_non_zero = + (digit_shift > 0 ? from_bits : Val{result, &b}) != 0; from_sign_bit = from_sign_bit & result_is_non_zero; - } else if (is_nuz(from_ty)) { + } else if (IsFNUZ(from_ty)) { // Clear the sign bit if the input is NaN (it's positive but encoded as // negative 0). from_sign_bit = from_sign_bit ^ input_is_nan; } + if (!from_ty.isFloat8E8M0FNU()) { + result = b.create(from_bits == 0, to_zero, result); + } result = b.create(result_is_inf, to_inf, result); - result = b.create(from_bits == 0, to_zero, result); result = b.create(input_is_nan, to_nan, result); - Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); - // Insert sign bit. - result = b.create(from_sign_bit, neg_result, result); + if (!from_ty.isFloat8E8M0FNU()) { + Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); + result = b.create(from_sign_bit, neg_result, result); + } result = b.create(to_ty, result); return result; } @@ -506,8 +548,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (dst_ty.getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit truncf"); + if (dst_ty.getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -524,8 +566,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (src.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit extf"); + if (src.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -544,8 +586,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { auto lhs = mlir::cast(op.getLhs()); auto rhs = mlir::cast(op.getRhs()); - if (lhs.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf"); + if (lhs.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -553,16 +595,16 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics()); if (op.getPredicate() == ma::CmpFPredicate::UNE && mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) { - Val int_value{b.create(rewriter.getI8Type(), lhs), &b}; + mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth()); + Val int_value{b.create(int_ty, lhs), &b}; int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. - if (rhs_cst.isZero() && - (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || - lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { - int_value = int_value & 0x7f; - constant &= 0x7f; + if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) { + int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1; + int_value = int_value & mask; + constant &= mask; } - auto cst = b.create(constant, rewriter.getI8Type()); + auto cst = b.create(constant, int_ty); rewriter.replaceOpWithNewOp(op, ma::CmpIPredicate::ne, int_value, cst); return mlir::success(); @@ -586,18 +628,23 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern { auto src = mlir::cast(op.getOperand()); // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16. // Once that's removed, remove the code for BF16 here. - if (src.getType().getWidth() != 8 && !src.getType().isBF16()) { - return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf"); + if (src.getType().getWidth() > 8 && !src.getType().isBF16()) { + return rewriter.notifyMatchFailure(op, + "not an f8 (or less) or bf16 absf"); } + + // If type is unsigned (E8M0), the operation is no-op. + if (!llvm::APFloat::semanticsHasSignedRepr( + src.getType().getFloatSemantics())) { + rewriter.replaceAllOpUsesWith(op, op.getOperand()); + return mlir::success(); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); Val value{b.create(i_ty, src), &b}; - if (src.getType().getWidth() == 8) { - value = value & 0x7f; - } else { - CHECK(src.getType().isBF16()); - value = value & 0x7fff; - } + int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1; + value = value & mask; rewriter.replaceOpWithNewOp(op, src.getType(), value); return mlir::success(); } @@ -609,8 +656,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 itofp"); + if (op.getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp"); } Value to_float = rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); @@ -625,8 +672,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getIn().getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 fptoi"); + if (op.getIn().getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi"); } Value to_f32 = rewriter.create( op.getLoc(), rewriter.getF32Type(), op.getIn()); diff --git a/xla/backends/gpu/codegen/transforms/lower_tensors.cc b/xla/backends/gpu/codegen/transforms/lower_tensors.cc index 38e3671f9613f1..31737323d78e4a 100644 --- a/xla/backends/gpu/codegen/transforms/lower_tensors.cc +++ b/xla/backends/gpu/codegen/transforms/lower_tensors.cc @@ -297,7 +297,8 @@ std::tuple GetI4IndexAndNibble(Value linear_index, mlir::LLVM::GEPOp CreateGep(TypedValue tensor, Value linear_index, mlir::ImplicitLocOpBuilder& b) { Type element_type = tensor.getType().getElementType(); - if (element_type == b.getI4Type()) { + if (element_type.isIntOrFloat() && + element_type.getIntOrFloatBitWidth() == 4) { element_type = b.getI8Type(); } auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); @@ -326,7 +327,8 @@ struct RewriteTensorExtract : OpRewritePattern { auto linear_index = GetLinearIndex(op.getIndices(), b); Type element_type = op.getTensor().getType().getElementType(); Value is_low_nibble = nullptr; - if (element_type == rewriter.getI4Type()) { + if (element_type.isIntOrFloat() && + element_type.getIntOrFloatBitWidth() == 4) { std::tie(linear_index, is_low_nibble) = GetI4IndexAndNibble(linear_index, b); } @@ -341,7 +343,7 @@ struct RewriteTensorExtract : OpRewritePattern { auto high_value = b.create( load, b.create(4, load.getType())); load = b.create( - op.getType(), + rewriter.getI4Type(), b.create(is_low_nibble, load, high_value)); } @@ -377,6 +379,7 @@ struct RewriteTransferRead : OpRewritePattern { auto source = mlir::dyn_cast>( op.getSource()); + mlir::Type source_element_type = source.getType().getElementType(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto linear_index = GetLinearIndex(op.getIndices(), b); @@ -385,7 +388,9 @@ struct RewriteTransferRead : OpRewritePattern { if (vector_type.getElementType().isInteger(1)) { vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type()); } - if (op.getVectorType().getElementType().isInteger(4)) { + mlir::Type gep_element_type = vector_type.getElementType(); + if (gep_element_type.isIntOrFloat() && + gep_element_type.getIntOrFloatBitWidth() == 4) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern { auto loaded = b.create(llvm_vector_type, gep).getResult(); - if (source.getType().getElementType().isInteger(1)) { + if (source_element_type.isInteger(1)) { Value zero = b.create( mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0))); loaded = b.create(arith::CmpIPredicate::ne, loaded, zero); - } else if (source.getType().getElementType().isInteger(4)) { + } else if (source_element_type.isIntOrFloat() && + source_element_type.getIntOrFloatBitWidth() == 4) { // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the // elements. loaded = PermutePairsInVector(loaded, b); @@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern { auto scalar_value = op.getScalar(); // For i4 we store 2 values into one byte. This needs special handling here. - if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) { + if (tensor_dest.getType().getElementType().isIntOrFloat() && + tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) { // We need to use directly op.getDest() as input, otherwise the following // rewrite might remove the only user of it. tensor_dest = op.getDest(); @@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern { auto tensor_dest_i8 = b.create(tensor_ty, tensor_dest) .getResult(0); + if (scalar_value.getType() != rewriter.getI4Type()) { + scalar_value = + b.create(rewriter.getI4Type(), scalar_value); + } scalar_value = b.create(ty, scalar_value); // We need AtomicRMWOp because it can happen that different threads try to @@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern { auto linear_index = GetLinearIndex(op.getIndices(), b); mlir::Value vector_value = op.getVector(); - if (op.getVectorType().getElementType().isInteger(1)) { + mlir::Type vector_element_type = op.getVectorType().getElementType(); + if (vector_element_type.isInteger(1)) { vector_value = b.create( op.getVectorType().cloneWith(std::nullopt, b.getI8Type()), vector_value); } - if (op.getVectorType().getElementType().isInteger(4)) { + if (vector_element_type.isIntOrFloat() && + vector_element_type.getIntOrFloatBitWidth() == 4) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -577,21 +590,19 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, // Needed to support complex element type. mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - if (mlir::isa(element_type)) { - int bit_width = mlir::cast(element_type).getWidth(); - if (bit_width == 4) { - num_elements = CeilOfRatio(num_elements, 2); - llvm_element_type = b.getI8Type(); - auto unpacked_data = - mlir::cast(value).getRawData(); - std::vector packed_data(num_elements); - absl::Span packed_data_span = - absl::MakeSpan(packed_data.data(), packed_data.size()); - PackIntN(4, unpacked_data, packed_data_span); - value = mlir::DenseElementsAttr::getFromRawBuffer( - mlir::RankedTensorType::get({num_elements}, llvm_element_type), - packed_data); - } + if (element_type.isIntOrFloat() && + element_type.getIntOrFloatBitWidth() == 4) { + num_elements = CeilOfRatio(num_elements, 2); + llvm_element_type = b.getI8Type(); + auto unpacked_data = + mlir::cast(value).getRawData(); + std::vector packed_data(num_elements); + absl::Span packed_data_span = + absl::MakeSpan(packed_data.data(), packed_data.size()); + PackIntN(4, unpacked_data, packed_data_span); + value = mlir::DenseElementsAttr::getFromRawBuffer( + mlir::RankedTensorType::get({num_elements}, llvm_element_type), + packed_data); } auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements); diff --git a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir index 442fe5e9291572..dea8988d474b05 100644 --- a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir @@ -115,3 +115,53 @@ module { // CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32 // CHECK: arith.truncf %[[EXT]] : f32 to f16 // CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 { + %ret = arith.extf %arg0 : f4E2M1FN to f16 + return %ret : f16 + } +} + +// CHECK-LABEL: @f4_to_f16 +// CHECK-NOT: arith.extf + +// ----- + +module { + func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN { + %ret = arith.truncf %arg0 : f16 to f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f16_to_f4 +// CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN { + %ret = math.absf %arg0 : f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f4_abs +// CHECK-NOT: math.absf +// CHECK: arith.constant 7 : i4 + +// ----- + +module { + func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU { + %ret = math.absf %arg0 : f8E8M0FNU + return %ret : f8E8M0FNU + } +} + +// CHECK-LABEL: @e8m0_abs +// CHECK-NOT: math.absf +// CHECK: return %arg0 diff --git a/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir b/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir index 646c7a00ff756f..864f68d1da6f49 100644 --- a/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir @@ -763,4 +763,44 @@ func.func @for_op(%arg0: tensor<500xf32>) -> f32 { // CHECK-LABEL: @for_op // CHECK: scf.for {{.*}} -> (vector<4xf32>) { -// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { \ No newline at end of file +// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { + +// ----- + +func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN { + %cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN> + %extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN> + %0 = arith.addf %extracted, %extracted_0 : f4E2M1FN + return %0 : f4E2M1FN +} +// CHECK: llvm.mlir.global private constant +// CHECK-SAME: dense<[25, 64]> +// CHECK-LABEL: @f4_constant + +// ----- + +func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> { + %c16 = arith.constant 16 : index + %c0 = arith.constant 0.0 : f4E2M1FN + %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN> + func.return %out : vector<2xf4E2M1FN> +} +// CHECK-LABEL: @transfer_read_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8] +// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4> +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN> +// CHECK: return %[[OUT]] : vector<2xf4E2M1FN> + +// ----- + +func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}, + %arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> { + %c10 = arith.constant 10 : index + %out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN> + func.return %out : tensor<43xf4E2M1FN> +} +// CHECK-LABEL: @transfer_write_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8 +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4> +// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr diff --git a/xla/comparison_util.h b/xla/comparison_util.h index 5a21595da4d741..44f0dd48640bb1 100644 --- a/xla/comparison_util.h +++ b/xla/comparison_util.h @@ -193,8 +193,13 @@ class Comparison { // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN // Reference: // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations - using R = SignedIntegerTypeForSizeType; - return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + if constexpr (std::numeric_limits::is_signed) { + using R = SignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } else { + using R = UnsignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } } } // Applies the comparison from this Comparison's direction and ordering. diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index 389d2d2a9a7aec..9787476f8f7eac 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "C128"; case XLA_FFI_DataType_TOKEN: return os << "TOKEN"; + case XLA_FFI_DataType_F4E2M1FN: + return os << "F4E2M1FN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; case XLA_FFI_DataType_F8E3M4: @@ -145,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "F8E5M2FNUZ"; case XLA_FFI_DataType_F8E4M3FNUZ: return os << "F8E4M3FNUZ"; + case XLA_FFI_DataType_F8E8M0FNU: + return os << "F8E8M0FNU"; } } diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index 8d6f1095fad24a..bf8cb7d1a8ad19 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -201,6 +201,8 @@ typedef enum { XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, XLA_FFI_DataType_F8E4M3FNUZ = 25, + XLA_FFI_DataType_F4E2M1FN = 32, + XLA_FFI_DataType_F8E8M0FNU = 33, } XLA_FFI_DataType; // LINT.ThenChange(ffi_test.cc) diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index f264451da34735..aeeab1d505ab66 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -79,6 +79,8 @@ enum class DataType : uint8_t { F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, F8E3M4 = XLA_FFI_DataType_F8E3M4, + F4E2M1FN = XLA_FFI_DataType_F4E2M1FN, + F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; inline constexpr DataType F8E3M4 = DataType::F8E3M4; +inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN; +inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: case DataType::F8E3M4: + case DataType::F4E2M1FN: + case DataType::F8E8M0FNU: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index f09588b9e986a2..3c51a0966ae02e 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); + EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); @@ -137,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); + EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU)); } TEST(FfiTest, DataTypeByteWidth) { @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128), ByteWidth(DataType::C128)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN), + ByteWidth(DataType::F4E2M1FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), @@ -193,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E4M3FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), ByteWidth(DataType::F8E3M4)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU), + ByteWidth(DataType::F8E8M0FNU)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 3fb2ac3c7786fa..7bcb14da445e8c 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -264,6 +264,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C64: case PrimitiveType::C128: case PrimitiveType::TOKEN: + case PrimitiveType::F4E2M1FN: case PrimitiveType::F8E5M2: case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: @@ -271,6 +272,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: case PrimitiveType::F8E3M4: + case PrimitiveType::F8E8M0FNU: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 3eb7c54f919b0a..8ea22d9d1602bf 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -119,6 +119,76 @@ class FP8E4M3DistanceTest : public ::testing::Test {}; using F8E4M3Types = ::testing::Types; TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); +TEST(FPDistanceTest, F4E2M1FNDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)), + 1); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)), + 2); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float4_e2m1fn(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float4_e2m1fn(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::min(), + tsl::float4_e2m1fn(0)), + 2); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + tsl::float4_e2m1fn(0)), + 2); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 4); +} + +TEST(FPDistanceTest, F8E8M0FNUDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)), + 0); + + // one step apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)), + 1); + + // two steps apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)), + 2); +} + TEST(FPDistanceTest, F8E3M4Distance) { // a & b are equal EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc index f2a77df3d7ddaa..620e907f8cf112 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F4E2M1FN: case F8E3M4: case F8E4M3: case F8E5M2: @@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, + F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, + F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc index 9755643b7586a0..126ba14f5bb39a 100644 --- a/xla/hlo/builder/lib/math_test.cc +++ b/xla/hlo/builder/lib/math_test.cc @@ -95,9 +95,13 @@ class MathTypedTest : public MathTest { Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); bool has_inf = std::numeric_limits::has_infinity; + bool has_nan = std::numeric_limits::has_quiet_NaN; + bool has_finite = !has_inf && !has_nan; + bool has_nan_only = !has_inf && has_nan; + auto expected = LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR1( - {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1({true, true, true, true, true, has_finite, + has_finite, has_finite, has_finite}), LiteralUtil::CreateR1({false, false, false, false, false, has_inf, has_inf, false, false}), LiteralUtil::CreateR1( @@ -105,7 +109,8 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1( {false, false, false, false, false, false, has_inf, false, false}), LiteralUtil::CreateR1({false, false, false, false, false, - !has_inf, !has_inf, true, true})); + has_nan_only, has_nan_only, has_nan, + has_nan})); ComputeAndCompareLiteral(&b, expected, {}); } @@ -118,10 +123,11 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), &b)); + bool is_mx = std::is_same_v; ComputeAndCompareLiteral( &b, LiteralUtil::CreateR1( - {has_negative_zero_v, false, false, false, false, false, false}), + {has_negative_zero_v, false, false, false, false, false, is_mx}), {}, error_spec_); } @@ -136,6 +142,9 @@ class MathTypedTest : public MathTest { // For good measure, we also check pow with an exponent other than 0.5. void TestSqrtPowInequivalence() { SetFastMathDisabled(true); + if (std::is_same_v) { + GTEST_SKIP() << "Skipping due to low precision"; + } // Tests disable constant folding by default, but this test needs it // enabled, otherwise we don't tickle the bug we're trying to catch. @@ -181,9 +190,14 @@ class MathTypedTest : public MathTest { &b); Erf(x); - bool has_inf = std::numeric_limits::has_infinity; - std::vector expected = { - has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)}; + bool inf_as_nan = !std::numeric_limits::has_infinity && + std::numeric_limits::has_quiet_NaN; + std::vector expected = {inf_as_nan ? nan : T(-1), + inf_as_nan ? nan : T(1), + T(-0), + T(0), + T(-1), + T(1)}; ComputeAndCompareR1(&b, expected, {}, error_spec_); } @@ -201,6 +215,10 @@ using TestTypes = #endif #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 double, +#endif +#ifndef XLA_TEST_BACKEND_TPU + // TODO(b/385004399): Run tests on these types on TPU. + tsl::float4_e2m1fn, #endif float>; diff --git a/xla/hlo/evaluator/BUILD b/xla/hlo/evaluator/BUILD index 9159d5db57626d..8884c0da32b12e 100644 --- a/xla/hlo/evaluator/BUILD +++ b/xla/hlo/evaluator/BUILD @@ -36,6 +36,7 @@ cc_library( "hlo_evaluator_typed_visitor_int4.cc", "hlo_evaluator_typed_visitor_int64.cc", "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_mxfloat.cc", "hlo_evaluator_typed_visitor_uint16.cc", "hlo_evaluator_typed_visitor_uint32.cc", "hlo_evaluator_typed_visitor_uint64.cc", diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 35fac878f104da..8e44243823c097 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -3722,7 +3722,7 @@ absl::StatusOr StochasticConvertOp(const Literal& operand_literal, const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { - bool is_negative = static_cast(Eigen::numext::signbit(operand)); + bool is_negative = static_cast(SignAndMagnitude(operand).first); if (Eigen::numext::isinf(operand)) { return is_negative ? std::numeric_limits::min() : std::numeric_limits::max(); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 41cd753d987201..7f0925f1a3179b 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1734,6 +1734,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; @@ -1741,6 +1742,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc new file mode 100644 index 00000000000000..6bc96c1a1f7cda --- /dev/null +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc @@ -0,0 +1,23 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "tsl/platform/ml_dtypes.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/xla/hlo/transforms/expanders/comparison_expander.cc b/xla/hlo/transforms/expanders/comparison_expander.cc index 0f09ecced1ebaf..86d1eeafcd5931 100644 --- a/xla/hlo/transforms/expanders/comparison_expander.cc +++ b/xla/hlo/transforms/expanders/comparison_expander.cc @@ -115,34 +115,41 @@ absl::StatusOr ComparisonExpander::ExpandInstruction( ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); - PrimitiveType signed_type = - primitive_util::SignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); - - auto zero_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); - - auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MinValue(signed_shape.element_type()))); - min_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, min_value, {})); - - auto max_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - max_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, max_value, {})); - - lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, - min_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, - min_value, max_value); + if (compare_type != F8E8M0FNU) { + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(signed_type))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); + } else { + auto int8_shape = ShapeUtil::ChangeElementType(lhs->shape(), U8); + lhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, lhs)); + rhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, rhs)); + } auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( - instruction->shape(), lhs, rhs, compare->direction(), - Comparison::Type::kSigned)); + instruction->shape(), lhs, rhs, compare->direction())); VLOG(2) << "New comparison instruction for total order:" << new_compare->ToString(); diff --git a/xla/hlo/transforms/simplifiers/float_normalization.cc b/xla/hlo/transforms/simplifiers/float_normalization.cc index 88dbd2781ca60f..cf978bf581fcde 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -217,6 +217,9 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) { if (subshape->element_type() == from) { subshape->set_element_type(to); + if (subshape->has_layout() && from == F4E2M1FN) { + subshape->mutable_layout()->set_element_size_in_bits(0); + } } }); float_normalization_->UpdateLayout(hlo->mutable_shape()); diff --git a/xla/hlo/transforms/simplifiers/float_normalization_test.cc b/xla/hlo/transforms/simplifiers/float_normalization_test.cc index 86ec889abc6527..b614f74229c0e5 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization_test.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -150,7 +150,9 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); + ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3, + F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, + F8E5M2, F8E5M2FNUZ, F8E8M0FNU)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index f70769ea91abec..cea1bc583ea56e 100644 --- a/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -69,6 +70,25 @@ ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral( } return ::mlir::DenseElementsAttr::getFromRawBuffer(type, packed_padded_data); + } else if constexpr (std::is_same_v) { + // DenseElementsAttr::get() does not support being passed an array of + // tsl::float4_e2m1fn. So convert each element to APFloat first. + auto data_span = literal.data(); + std::vector apfloats; + apfloats.reserve(literal.element_count()); + for (size_t i = 0; i < literal.element_count(); i++) { + llvm::APFloat apfloat{static_cast(data_span[i])}; + bool losesInfo; + llvm::APFloat::opStatus status = + apfloat.convert(llvm::APFloat::Float4E2M1FN(), + llvm::APFloat::rmNearestTiesToEven, &losesInfo); + CHECK_EQ(status, llvm::APFloat::opOK) + << "Failed to convert " << data_span[i] << " to Float4E2M1FN APFloat"; + CHECK(!losesInfo) << "Lost info when converting " << data_span[i] + << " to Float4E2M1FN APFloat"; + apfloats.push_back(apfloat); + } + return ::mlir::DenseElementsAttr::get(type, apfloats); } else { auto data_span = literal.data(); return ::mlir::DenseElementsAttr::get( diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 3a1e7ceabb160f..577e4ad61f89e2 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -421,6 +421,12 @@ add { // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + %constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_15:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + %constant.15 = f8e8m0fnu[4] constant({1, 2, 4, 8}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -542,7 +548,19 @@ add { %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> - ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) + %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) + + // CHECK-NEXT: %14 = mhlo.convert %13 : (tensor<4xf32>) -> tensor<4xf4E2M1FN> + %convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16) + + // CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32> + %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17) + + // CHECK-NEXT: %16 = mhlo.convert %15 : (tensor<4xf32>) -> tensor<4xf8E8M0FNU> + %convert.19 = f8e8m0fnu[4] convert(f32[4] %convert.18) + + // CHECK-NEXT: %17 = mhlo.convert %16 : (tensor<4xf8E8M0FNU>) -> tensor<4xf32> + ROOT %convert.20 = f32[4] convert(f8e8m0fnu[4] %convert.19) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc index 821f1487cf88c1..f50e2a097a3277 100644 --- a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc +++ b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc @@ -41,6 +41,12 @@ xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { xla::Array array(shape.dimensions()); if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { array.SetValues(dense_attr.getValues()); + } else if constexpr (xla::primitive_util::IsMXType(type)) { + // Bitcast MX floating point types from APFloat. + auto values = dense_attr.getValues(); + for (int i = 0; i < values.size(); i++) { + array.data()[i] = T::FromRep(values[i].bitcastToAPInt().getZExtValue()); + } } else { // The only way to get subbyte integers from getValues() is to get them as // APInts. diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index a22ec331d93b20..c017751477cb51 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -606,6 +606,12 @@ func.func @main() { // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + // CHECK: f4e2m1fn[4] constant({1, 2, 3, 4}) + %cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + + // CHECK: f8e8m0fnu[4] constant({1, 2, 4, 8}) + %cst_19 = arith.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + func.return } @@ -739,7 +745,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> - func.return %11 : tensor<2xf32> + %12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN> + %13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32> + %14 = "mhlo.convert"(%13) : (tensor<2xf32>) -> tensor<2xf8E8M0FNU> + %15 = "mhlo.convert"(%14) : (tensor<2xf8E8M0FNU>) -> tensor<2xf32> + func.return %15 : tensor<2xf32> } // CHECK: ENTRY @@ -755,7 +765,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) // CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) // CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) -// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) +// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) +// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]]) +// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) +// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]]) +// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 997f44a4dd0f62..866bc1838a9190 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -91,10 +91,11 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || - !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || - !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || - !proto.f8e3m4s().empty() || !proto.f16s().empty() || + !proto.f4e2m1fns().empty() || !proto.f8e3m4s().empty() || + !proto.f8e4m3b11fnuzs().empty() || !proto.f8e4m3fns().empty() || + !proto.f8e4m3fnuzs().empty() || !proto.f8e4m3s().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e5m2s().empty() || + !proto.f8e8m0fnus().empty() || !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || proto.preds_size() || proto.tuple_literals_size(); @@ -1874,7 +1875,6 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(size_bytes_dense(), other.size_bytes_dense()); if (primitive_util::IsSubByteNonPredType(subshape().element_type())) { - CHECK(!primitive_util::IsFloatingPointType(subshape().element_type())); auto one_array = buffer(); auto two_array = other.buffer(); const int bits_per_element = @@ -2259,6 +2259,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case F4E2M1FN: + *proto->mutable_f4e2m1fns() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E5M2: *proto->mutable_f8e5m2s() = std::string( reinterpret_cast(data().data()), @@ -2294,6 +2299,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E8M0FNU: + *proto->mutable_f8e8m0fnus() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2445,6 +2455,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case F4E2M1FN: { + const std::string& s(proto.f4e2m1fns()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float4_e2m1fn) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E5M2: { const std::string& s(proto.f8e5m2s()); TF_RET_CHECK(data().size() * sizeof(tsl::float8_e5m2) == @@ -2498,6 +2516,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E8M0FNU: { + const std::string& s(proto.f8e8m0fnus()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e8m0fnu) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal.h b/xla/literal.h index 0c028bd1aa60ea..db40cd7f650031 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -589,18 +589,17 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { - static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte = 0; for (int b = 0; b < elements_per_byte; ++b) { - uint8_t src = - static_cast(elements[i * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = Eigen::numext::bit_cast( + elements[i * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -609,9 +608,9 @@ class LiteralBase { if (rest != 0) { uint8_t byte = 0; for (int64_t b = 0; b < rest; ++b) { - uint8_t src = - static_cast(elements[bytes * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = Eigen::numext::bit_cast( + elements[bytes * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -701,11 +700,17 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { - static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + constexpr auto cast = [](uint8_t x) -> NativeT { + if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { + return Eigen::numext::bit_cast(x); + } + return static_cast(x); + }; + + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte; @@ -714,7 +719,7 @@ class LiteralBase { } for (int b = 0; b < elements_per_byte; ++b) { elements[i * elements_per_byte + b] = - static_cast(byte & LsbMask(bits_per_element)); + cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } @@ -726,7 +731,7 @@ class LiteralBase { } for (int64_t b = 0; b < rest; ++b) { elements[bytes * elements_per_byte + b] = - static_cast(byte & LsbMask(bits_per_element)); + cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index c97629594122bb..ecea5024963934 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -206,8 +206,8 @@ template std::string FpValueToString(NativeT value) { if constexpr (is_specialized_floating_point_v) { constexpr int kPrecisionDigits = std::numeric_limits::max_digits10; - const int kExponentDigts = - std::ceil(std::log10(std::numeric_limits::max_exponent10)); + const int kExponentDigts = std::ceil( + std::log10(std::max(std::numeric_limits::max_exponent10, 1))); constexpr int kExtraChars = 4; const int kTotalChars = kPrecisionDigits * kExponentDigts + kExtraChars; return absl::StrFormat("%*.*g", kTotalChars, kPrecisionDigits, @@ -418,6 +418,9 @@ class NearComparator { } else { float_distance = CalculateFloatDistance(expected, actual); abs_error = FpAbsoluteValue(actual - expected); + if (!std::numeric_limits::is_signed && IsNaN(abs_error)) { + abs_error = FpAbsoluteValue(expected - actual); + } // Avoid division by 0 even though it's well-defined because ubsan can be // configured to treat this as a fatal error. diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 7713aceaaa3bc5..29c12eb7c75e4a 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -30,13 +30,15 @@ template class LiteralComparisonTest : public ::testing::Test {}; using TestedTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + auto expected = LiteralUtil::CreateR0(TypeParam(1.0)); TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); @@ -44,12 +46,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = 9.0; // F8E4M3* - if (type == F8E5M2) - expV = 10.0; + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + float expV = 1.125; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 1.25; else if (type == F8E3M4) - expV = 8.5; + expV = 1.0625; + else if (type == F4E2M1FN) + expV = 1.5; + else if (type == F8E8M0FNU) + expV = 2.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, @@ -64,12 +70,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = 12.0; // F8E4M3* - if (type == F8E5M2) - expV = 14.0; + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + float expV = 1.5; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 2.0; else if (type == F8E3M4) - expV = 10.0; + expV = 1.25; + else if (type == F4E2M1FN) + expV = 4.0; + else if (type == F8E8M0FNU) + expV = 16.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; @@ -86,12 +96,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(8.0); - float expV = 12.1; // F8E4M3* - if (type == F8E5M2) - expV = 13.0; + auto actual = LiteralUtil::CreateR0(1.0); + float expV = 1.51; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 2.01; else if (type == F8E3M4) - expV = 10.125; + expV = 1.26; + else if (type == F4E2M1FN) + expV = 4.1; + else if (type == F8E8M0FNU) + expV = 16.5; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 44e4acd6a5cef7..7aa9f2dc040dcd 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -124,11 +124,11 @@ class LiteralUtilTest : public ::testing::Test { template class LiteralUtilFloatTest : public LiteralUtilTest {}; -using FloatTypes = - ::testing::Types; +using FloatTypes = ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -175,6 +175,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { LiteralUtil::CreateR0(static_cast(9.001f)); EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); + auto f4e2m1fn_lit = + LiteralUtil::CreateR0(tsl::float4_e2m1fn(0.5)); + EXPECT_EQ("f4e2m1fn[] 0.5", f4e2m1fn_lit.ToString()); + auto f8e5m2_lit = LiteralUtil::CreateR0(tsl::float8_e5m2(0.5)); EXPECT_EQ("f8e5m2[] 0.5", f8e5m2_lit.ToString()); @@ -207,6 +211,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e3m4_lit = LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); + + auto f8e8m0fnu_lit = + LiteralUtil::CreateR0(tsl::float8_e8m0fnu(0.5)); + EXPECT_EQ("f8e8m0fnu[] 0.5", f8e8m0fnu_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -659,6 +667,11 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); + tsl::float4_e2m1fn m16(4); + EXPECT_TRUE(LiteralUtil::CreateR1({m16}).IsAll(4)); + // 5 rounds to 4 in E2M1FN but is not equal to 4, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({m16}).IsAll(5)); + tsl::float8_e5m2 p16(8); EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false @@ -689,6 +702,11 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + tsl::float8_e8m0fnu w16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({w16}).IsAll(8)); + // 9 rounds to 8 in E8M0FNU but is not equal to 8, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({w16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2214,6 +2232,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + using e2m1 = tsl::float4_e2m1fn; + auto vector_f4e2m1fn = + LiteralUtil::CreateR1({e2m1{1.0}, e2m1{2.0}, e2m1{-3.0}}); using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); @@ -2234,6 +2255,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); using e3 = tsl::float8_e3m4; auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); + using e8m0 = tsl::float8_e8m0fnu; + auto vector_f8e8m0fnu = + LiteralUtil::CreateR1({e8m0{1.0}, e8m0{2.0}, e8m0{4.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2254,13 +2278,15 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); - EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f4e2m1fn, to_from_proto(vector_f4e2m1fn)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); - EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); - EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); - EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); + EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e8m0fnu, to_from_proto(vector_f8e8m0fnu)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2511,19 +2537,19 @@ TEST_F(LiteralUtilTest, SliceOnBool) { } TEST_F(LiteralUtilTest, IsEqualAt) { - double val_double = 10.0; - int val_integral = 10; - Literal c1 = LiteralUtil::CreateR0(10); + double val_double = 4.0; + int val_integral = 4; + Literal c1 = LiteralUtil::CreateR0(val_integral); EXPECT_TRUE(c1.IsEqualAt({}, val_double)); EXPECT_TRUE(c1.IsEqualAt({}, val_integral)); - Literal c2 = LiteralUtil::CreateR0(10); + Literal c2 = LiteralUtil::CreateR0(val_double); EXPECT_TRUE(c2.IsEqualAt({}, val_double)); EXPECT_TRUE(c2.IsEqualAt({}, val_integral)); Literal c3 = LiteralUtil::CreateR0(tsl::float8_e5m2{val_double}); EXPECT_TRUE(c3.IsEqualAt({}, val_double)); EXPECT_TRUE(c3.IsEqualAt({}, val_integral)); - complex128 val_complex = {10, 0}; + complex128 val_complex = {val_double, 0}; EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); EXPECT_TRUE(c3.IsEqualAt({}, val_complex)); @@ -2532,8 +2558,8 @@ TEST_F(LiteralUtilTest, IsEqualAt) { EXPECT_TRUE(c4.IsEqualAt({}, val_integral)); EXPECT_TRUE(c4.IsEqualAt({}, val_complex)); EXPECT_FALSE(c4.IsEqualAt({}, std::numeric_limits::infinity())); - complex128 val_true_complex = {10, 3}; - complex64 val_smaller_complex = {10, 3}; + complex128 val_true_complex = {val_double, 3}; + complex64 val_smaller_complex = {static_cast(val_double), 3}; Literal c5 = LiteralUtil::CreateR0(val_true_complex); EXPECT_TRUE(c5.IsEqualAt({}, val_true_complex)); EXPECT_TRUE(c5.IsEqualAt({}, val_smaller_complex)); @@ -2557,6 +2583,14 @@ TEST_F(LiteralUtilTest, IsEqualAt) { LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); EXPECT_TRUE(c10.IsEqualAt({}, val_double)); EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); + Literal c11 = + LiteralUtil::CreateR0(tsl::float4_e2m1fn{val_double}); + EXPECT_TRUE(c11.IsEqualAt({}, val_double)); + EXPECT_TRUE(c11.IsEqualAt({}, val_integral)); + Literal c12 = LiteralUtil::CreateR0( + tsl::float8_e8m0fnu{val_double}); + EXPECT_TRUE(c12.IsEqualAt({}, val_double)); + EXPECT_TRUE(c12.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2882,10 +2916,11 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F4E2M1FN, F8E3M4, F8E4M3, + F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F8E8M0FNU, + C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 2581390a1e13d7..ea8da4d4990d9d 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -32,6 +32,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( switch (type) { case xla::PrimitiveType::PRED: return b.getI1Type(); + case xla::PrimitiveType::F4E2M1FN: + return b.getFloat4E2M1FNType(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); case xla::PrimitiveType::F8E4M3: @@ -46,6 +48,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E4M3FNUZType(); case xla::PrimitiveType::F8E3M4: return b.getFloat8E3M4Type(); + case xla::PrimitiveType::F8E8M0FNU: + return b.getFloat8E8M0FNUType(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -78,7 +82,9 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( } xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { - if (type.isFloat8E5M2()) { + if (type.isFloat4E2M1FN()) { + return xla::PrimitiveType::F4E2M1FN; + } else if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; } else if (type.isFloat8E4M3()) { return xla::PrimitiveType::F8E4M3; @@ -92,6 +98,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E5M2FNUZ; } else if (type.isFloat8E3M4()) { return xla::PrimitiveType::F8E3M4; + } else if (type.isFloat8E8M0FNU()) { + return xla::PrimitiveType::F8E8M0FNU; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index a8043ab0b5f140..2239943d906b7b 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P( Execute, TypeUtilTest, ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, + {F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, @@ -111,6 +112,7 @@ INSTANTIATE_TEST_SUITE_P( {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, + {F8E8M0FNU, [](mlir::Builder b) { return b.getFloat8E8M0FNUType(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index d07a178c6c4e7f..16a64cdc22b768 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e3m4(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor @@ -6872,6 +6879,13 @@ func.func @f8e5m2(%arg0: tensor) -> tensor { // ----- +func.func @f8e8m0fnu(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @top_k_1d(%arg0 : tensor<16xf32>) { %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 5852c9a54dcc01..fe9158f41e337e 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,4 +1,7 @@ # PJRT C API changelog +## 0.61 +* Added types F4E2M1FN and F8E8M0FNU. + ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 36d82b0787ba41..61a1f8785bc581 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 60 +#define PJRT_API_MINOR 61 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -681,6 +681,10 @@ typedef enum { // More truncated 8 bit floating-point formats. PJRT_Buffer_Type_F8E4M3, PJRT_Buffer_Type_F8E3M4, + PJRT_Buffer_Type_F8E8M0FNU, + + // 4-bit MX floating-point format. + PJRT_Buffer_Type_F4E2M1FN, } PJRT_Buffer_Type; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 2060a73a634a48..b1ad44329a40ef 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -310,6 +310,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_BF16; case xla::PrimitiveType::F64: return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; + case xla::PrimitiveType::F4E2M1FN: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; case xla::PrimitiveType::F8E4M3: @@ -324,6 +326,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; case xla::PrimitiveType::F8E3M4: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; + case xla::PrimitiveType::F8E8M0FNU: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -377,6 +381,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C64; case PJRT_Buffer_Type::PJRT_Buffer_Type_C128: return xla::PrimitiveType::C128; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: @@ -391,6 +397,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E4M3FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: return xla::PrimitiveType::F8E3M4; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.cc b/xla/primitive_util.cc index b70ba275a1f47f..5006406ea99779 100644 --- a/xla/primitive_util.cc +++ b/xla/primitive_util.cc @@ -93,6 +93,18 @@ bool HasInfinity(PrimitiveType type) { return false; } +bool HasNaN(PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + return FloatingPointTypeSwitch( + [&](auto constant_type) -> bool { + return std::numeric_limits< + NativeTypeOf>::has_quiet_NaN; + }, + type); + } + return false; +} + bool HasNegativeZero(PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { return FloatingPointTypeSwitch( diff --git a/xla/primitive_util.h b/xla/primitive_util.h index b9c1c978bc620e..70a8335c8bc518 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -69,6 +69,9 @@ int ExponentBias(PrimitiveType type); // Returns whether the type has a value for infinity. bool HasInfinity(PrimitiveType type); +// Returns whether the type has a value for NaN. +bool HasNaN(PrimitiveType type); + // Returns whether the type has a value for negative zero. bool HasNegativeZero(PrimitiveType type); @@ -185,6 +188,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return BF16; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F4E2M1FN; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; @@ -220,6 +228,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E3M4; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E8M0FNU; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -334,6 +347,11 @@ struct PrimitiveTypeToNative { using type = bfloat16; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float4_e2m1fn; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; @@ -369,6 +387,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e3m4; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e8m0fnu; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -401,6 +424,10 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { primitive_type < PrimitiveType_ARRAYSIZE; } +constexpr bool IsMXType(PrimitiveType type) { + return type == F4E2M1FN || type == F8E8M0FNU; +} + constexpr bool IsF8Type(PrimitiveType type) { return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || @@ -409,7 +436,7 @@ constexpr bool IsF8Type(PrimitiveType type) { constexpr bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16 || - IsF8Type(type); + IsF8Type(type) || IsMXType(type); } constexpr bool IsComplexType(PrimitiveType type) { @@ -473,6 +500,9 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F4E2M1FN: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E3M4: return std::forward(f)( PrimitiveTypeConstant()); @@ -494,6 +524,9 @@ constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { case F8E5M2FNUZ: return std::forward(f)( PrimitiveTypeConstant()); + case F8E8M0FNU: + return std::forward(f)( + PrimitiveTypeConstant()); case F16: return std::forward(f)(PrimitiveTypeConstant()); case BF16: @@ -577,6 +610,9 @@ inline constexpr int PrimitiveTypeBitWidth() { if constexpr (primitive_type == PRED) { return std::numeric_limits::digits; } + if constexpr (IsMXType(primitive_type)) { + return NativeT::kBits; + } if constexpr (IsFloatingPointType(primitive_type)) { return sizeof(NativeT) * std::numeric_limits::digits; } @@ -715,6 +751,10 @@ inline bool CastPreservesValues(PrimitiveType from_type, if (from_type == to_type) { return true; } + // * -> F8E8M0FNU is not possible because zero cannot be represented. + if (to_type == F8E8M0FNU) { + return false; + } // PRED -> * if (from_type == PRED) { return true; @@ -737,21 +777,33 @@ inline bool CastPreservesValues(PrimitiveType from_type, return false; } // F -> F is safe if the exponent/significand are preserved and `to_type` - // preserves infinities in `from_type. + // preserves infinities/nans/unsigned zero in `from_type`. if (primitive_util::IsFloatingPointType(from_type) && primitive_util::IsFloatingPointType(to_type)) { - return (!primitive_util::HasInfinity(from_type) || - primitive_util::HasInfinity(to_type)) && - primitive_util::SignificandWidth(from_type) <= - primitive_util::SignificandWidth(to_type) && - primitive_util::ExponentWidth(from_type) <= - primitive_util::ExponentWidth(to_type) && - (primitive_util::UnderflowExponent(from_type) - - primitive_util::SignificandWidth(from_type)) >= - (primitive_util::UnderflowExponent(to_type) - - primitive_util::SignificandWidth(to_type)) && - primitive_util::OverflowExponent(from_type) <= - primitive_util::OverflowExponent(to_type); + return + // Target mantissa should be large enough. + primitive_util::SignificandWidth(from_type) <= + primitive_util::SignificandWidth(to_type) && + // Target exponent should be large enough. + primitive_util::ExponentWidth(from_type) <= + primitive_util::ExponentWidth(to_type) && + // HasInfinity check. + (!primitive_util::HasInfinity(from_type) || + primitive_util::HasInfinity(to_type)) && + // HasNaN check. + (!primitive_util::HasNaN(from_type) || + primitive_util::HasNaN(to_type)) && + // HasNegativeZero check. + (!primitive_util::HasNegativeZero(from_type) || + primitive_util::HasNegativeZero(to_type)) && + // Minimum denormal should be representable by target type. + (primitive_util::UnderflowExponent(from_type) - + primitive_util::SignificandWidth(from_type)) >= + (primitive_util::UnderflowExponent(to_type) - + primitive_util::SignificandWidth(to_type)) && + // Maximum exponent may be larger with custom bias (e.g. F8E4M3B11FNUZ). + primitive_util::OverflowExponent(from_type) <= + primitive_util::OverflowExponent(to_type); } // F -> I is not safe because it drops fractional numbers. if (!primitive_util::IsIntegralType(from_type)) { diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index 190e6442d03263..68fad70096812e 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -69,8 +69,9 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][F8E4M3] = expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = expecteds[PRED][F8E3M4] = true; + expecteds[PRED][F4E2M1FN] = true; + expecteds[PRED][F8E8M0FNU] = false; expecteds[S1][PRED] = false; - expecteds[S2][PRED] = false; expecteds[S1][S1] = true; expecteds[S1][S2] = true; expecteds[S1][S4] = true; @@ -91,6 +92,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][C64] = true; expecteds[S1][BF16] = true; expecteds[S1][C128] = true; + expecteds[S1][F4E2M1FN] = true; expecteds[S1][F8E5M2] = true; expecteds[S1][F8E4M3] = true; expecteds[S1][F8E4M3FN] = true; @@ -98,8 +100,11 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][F8E5M2FNUZ] = true; expecteds[S1][F8E4M3FNUZ] = true; expecteds[S1][F8E3M4] = true; + expecteds[S1][F8E8M0FNU] = false; + expecteds[S2][PRED] = false; expecteds[S2][S1] = false; - expecteds[S2][S2] = expecteds[S2][S4] = true; + expecteds[S2][S2] = true; + expecteds[S2][S4] = true; expecteds[S2][S8] = true; expecteds[S2][S16] = true; expecteds[S2][S32] = true; @@ -117,6 +122,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][C64] = true; expecteds[S2][BF16] = true; expecteds[S2][C128] = true; + expecteds[S2][F4E2M1FN] = true; expecteds[S2][F8E5M2] = true; expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; @@ -124,6 +130,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; expecteds[S2][F8E3M4] = true; + expecteds[S2][F8E8M0FNU] = false; expecteds[S4][PRED] = false; expecteds[S4][S1] = false; expecteds[S4][S2] = false; @@ -145,6 +152,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][C64] = true; expecteds[S4][BF16] = true; expecteds[S4][C128] = true; + expecteds[S4][F4E2M1FN] = false; expecteds[S4][F8E5M2] = true; expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; @@ -152,6 +160,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; expecteds[S4][F8E3M4] = true; + expecteds[S4][F8E8M0FNU] = false; expecteds[S8][PRED] = false; expecteds[S8][S1] = false; expecteds[S8][S2] = false; @@ -173,6 +182,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][C64] = true; expecteds[S8][BF16] = true; expecteds[S8][C128] = true; + expecteds[S8][F4E2M1FN] = false; expecteds[S8][F8E5M2] = false; expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; @@ -180,6 +190,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; expecteds[S8][F8E3M4] = false; + expecteds[S8][F8E8M0FNU] = false; expecteds[S16][PRED] = false; expecteds[S16][S1] = false; expecteds[S16][S2] = false; @@ -201,6 +212,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][C64] = true; expecteds[S16][BF16] = false; expecteds[S16][C128] = true; + expecteds[S16][F4E2M1FN] = false; expecteds[S16][F8E5M2] = false; expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; @@ -208,6 +220,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; expecteds[S16][F8E3M4] = false; + expecteds[S16][F8E8M0FNU] = false; expecteds[S32][PRED] = false; expecteds[S32][S1] = false; expecteds[S32][S2] = false; @@ -229,6 +242,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][C64] = false; expecteds[S32][BF16] = false; expecteds[S32][C128] = true; + expecteds[S32][F4E2M1FN] = false; expecteds[S32][F8E5M2] = false; expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; @@ -236,6 +250,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; expecteds[S32][F8E3M4] = false; + expecteds[S32][F8E8M0FNU] = false; expecteds[S64][PRED] = false; expecteds[S64][S1] = false; expecteds[S64][S2] = false; @@ -257,6 +272,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][C64] = false; expecteds[S64][BF16] = false; expecteds[S64][C128] = false; + expecteds[S64][F4E2M1FN] = false; expecteds[S64][F8E5M2] = false; expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; @@ -264,6 +280,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; expecteds[S64][F8E3M4] = false; + expecteds[S64][F8E8M0FNU] = false; expecteds[U1][PRED] = false; expecteds[U1][S1] = false; expecteds[U1][S2] = true; @@ -285,8 +302,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][C64] = true; expecteds[U1][BF16] = true; expecteds[U1][C128] = true; - expecteds[U1][BF16] = true; - expecteds[U1][C128] = true; + expecteds[U1][F4E2M1FN] = true; expecteds[U1][F8E5M2] = true; expecteds[U1][F8E4M3] = true; expecteds[U1][F8E4M3FN] = true; @@ -294,14 +310,16 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][F8E5M2FNUZ] = true; expecteds[U1][F8E4M3FNUZ] = true; expecteds[U1][F8E3M4] = true; + expecteds[U1][F8E8M0FNU] = false; expecteds[U2][PRED] = false; - expecteds[U2][U1] = expecteds[U2][S1] = false; + expecteds[U2][S1] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; expecteds[U2][S8] = true; expecteds[U2][S16] = true; expecteds[U2][S32] = true; expecteds[U2][S64] = true; + expecteds[U2][U1] = false; expecteds[U2][U2] = true; expecteds[U2][U4] = true; expecteds[U2][U8] = true; @@ -314,8 +332,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][C64] = true; expecteds[U2][BF16] = true; expecteds[U2][C128] = true; - expecteds[U2][BF16] = true; - expecteds[U2][C128] = true; + expecteds[U2][F4E2M1FN] = true; expecteds[U2][F8E5M2] = true; expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; @@ -323,6 +340,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; expecteds[U2][F8E3M4] = true; + expecteds[U2][F8E8M0FNU] = false; expecteds[U4][PRED] = false; expecteds[U4][S1] = false; expecteds[U4][S2] = false; @@ -344,8 +362,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][C64] = true; expecteds[U4][BF16] = true; expecteds[U4][C128] = true; - expecteds[U4][BF16] = true; - expecteds[U4][C128] = true; + expecteds[U4][F4E2M1FN] = false; expecteds[U4][F8E5M2] = false; expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; @@ -353,6 +370,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; expecteds[U4][F8E3M4] = true; + expecteds[U4][F8E8M0FNU] = false; expecteds[U8][PRED] = false; expecteds[U8][S1] = false; expecteds[U8][S2] = false; @@ -374,8 +392,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][C64] = true; expecteds[U8][BF16] = true; expecteds[U8][C128] = true; - expecteds[U8][BF16] = true; - expecteds[U8][C128] = true; + expecteds[U8][F4E2M1FN] = false; expecteds[U8][F8E5M2] = false; expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; @@ -383,6 +400,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; expecteds[U8][F8E3M4] = false; + expecteds[U8][F8E8M0FNU] = false; expecteds[U16][PRED] = false; expecteds[U16][S1] = false; expecteds[U16][S2] = false; @@ -404,6 +422,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][C64] = true; expecteds[U16][BF16] = false; expecteds[U16][C128] = true; + expecteds[U16][F4E2M1FN] = false; expecteds[U16][F8E5M2] = false; expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; @@ -411,6 +430,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; expecteds[U16][F8E3M4] = false; + expecteds[U16][F8E8M0FNU] = false; expecteds[U32][PRED] = false; expecteds[U32][S1] = false; expecteds[U32][S2] = false; @@ -432,6 +452,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][C64] = false; expecteds[U32][BF16] = false; expecteds[U32][C128] = true; + expecteds[U32][F4E2M1FN] = false; expecteds[U32][F8E5M2] = false; expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; @@ -439,6 +460,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; expecteds[U32][F8E3M4] = false; + expecteds[U32][F8E8M0FNU] = false; expecteds[U64][PRED] = false; expecteds[U64][S1] = false; expecteds[U64][S2] = false; @@ -460,6 +482,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][C64] = false; expecteds[U64][BF16] = false; expecteds[U64][C128] = false; + expecteds[U64][F4E2M1FN] = false; expecteds[U64][F8E5M2] = false; expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; @@ -467,6 +490,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; expecteds[U64][F8E3M4] = false; + expecteds[U64][F8E8M0FNU] = false; expecteds[F16][PRED] = false; expecteds[F16][S1] = false; expecteds[F16][S2] = false; @@ -488,6 +512,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][C64] = true; expecteds[F16][BF16] = false; expecteds[F16][C128] = true; + expecteds[F16][F4E2M1FN] = false; expecteds[F16][F8E5M2] = false; expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; @@ -495,6 +520,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; expecteds[F16][F8E3M4] = false; + expecteds[F16][F8E8M0FNU] = false; expecteds[F32][PRED] = false; expecteds[F32][S1] = false; expecteds[F32][S2] = false; @@ -516,6 +542,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][C64] = true; expecteds[F32][BF16] = false; expecteds[F32][C128] = true; + expecteds[F32][F4E2M1FN] = false; expecteds[F32][F8E5M2] = false; expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; @@ -523,6 +550,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; expecteds[F32][F8E3M4] = false; + expecteds[F32][F8E8M0FNU] = false; expecteds[F64][PRED] = false; expecteds[F64][S1] = false; expecteds[F64][S2] = false; @@ -544,6 +572,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][C64] = false; expecteds[F64][BF16] = false; expecteds[F64][C128] = true; + expecteds[F64][F4E2M1FN] = false; expecteds[F64][F8E5M2] = false; expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; @@ -551,6 +580,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; expecteds[F64][F8E3M4] = false; + expecteds[F64][F8E8M0FNU] = false; expecteds[C64][PRED] = false; expecteds[C64][S1] = false; expecteds[C64][S2] = false; @@ -572,6 +602,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][C64] = true; expecteds[C64][BF16] = false; expecteds[C64][C128] = true; + expecteds[C64][F4E2M1FN] = false; expecteds[C64][F8E5M2] = false; expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; @@ -579,6 +610,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; expecteds[C64][F8E3M4] = false; + expecteds[C64][F8E8M0FNU] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S1] = false; expecteds[BF16][S2] = false; @@ -600,6 +632,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][C64] = true; expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; + expecteds[BF16][F4E2M1FN] = false; expecteds[BF16][F8E5M2] = false; expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; @@ -607,6 +640,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; expecteds[BF16][F8E3M4] = false; + expecteds[BF16][F8E8M0FNU] = false; expecteds[C128][PRED] = false; expecteds[C128][S1] = false; expecteds[C128][S2] = false; @@ -628,6 +662,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][C64] = false; expecteds[C128][BF16] = false; expecteds[C128][C128] = true; + expecteds[C128][F4E2M1FN] = false; expecteds[C128][F8E5M2] = false; expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; @@ -635,6 +670,37 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; expecteds[C128][F8E3M4] = false; + expecteds[C128][F8E8M0FNU] = false; + expecteds[F4E2M1FN][PRED] = false; + expecteds[F4E2M1FN][S1] = false; + expecteds[F4E2M1FN][S2] = false; + expecteds[F4E2M1FN][S4] = false; + expecteds[F4E2M1FN][S8] = false; + expecteds[F4E2M1FN][S16] = false; + expecteds[F4E2M1FN][S32] = false; + expecteds[F4E2M1FN][S64] = false; + expecteds[F4E2M1FN][U1] = false; + expecteds[F4E2M1FN][U2] = false; + expecteds[F4E2M1FN][U4] = false; + expecteds[F4E2M1FN][U8] = false; + expecteds[F4E2M1FN][U16] = false; + expecteds[F4E2M1FN][U32] = false; + expecteds[F4E2M1FN][U64] = false; + expecteds[F4E2M1FN][F16] = true; + expecteds[F4E2M1FN][F32] = true; + expecteds[F4E2M1FN][F64] = true; + expecteds[F4E2M1FN][C64] = true; + expecteds[F4E2M1FN][BF16] = true; + expecteds[F4E2M1FN][C128] = true; + expecteds[F4E2M1FN][F4E2M1FN] = true; + expecteds[F4E2M1FN][F8E5M2] = true; + expecteds[F4E2M1FN][F8E4M3] = true; + expecteds[F4E2M1FN][F8E4M3FN] = true; + expecteds[F4E2M1FN][F8E4M3B11FNUZ] = false; + expecteds[F4E2M1FN][F8E4M3FNUZ] = false; + expecteds[F4E2M1FN][F8E5M2FNUZ] = false; + expecteds[F4E2M1FN][F8E3M4] = true; + expecteds[F4E2M1FN][F8E8M0FNU] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S1] = false; expecteds[F8E5M2][S2] = false; @@ -656,6 +722,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][C64] = true; expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; + expecteds[F8E5M2][F4E2M1FN] = false; expecteds[F8E5M2][F8E5M2] = true; expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; @@ -663,6 +730,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E5M2][F8E8M0FNU] = false; expecteds[F8E4M3][PRED] = false; expecteds[F8E4M3][S1] = false; expecteds[F8E4M3][S2] = false; @@ -684,6 +752,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][C64] = true; expecteds[F8E4M3][BF16] = true; expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F4E2M1FN] = false; expecteds[F8E4M3][F8E5M2] = false; expecteds[F8E4M3][F8E5M2FNUZ] = false; expecteds[F8E4M3][F8E4M3] = true; @@ -691,6 +760,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][F8E4M3FNUZ] = false; expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3][F8E3M4] = false; + expecteds[F8E4M3][F8E8M0FNU] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S1] = false; expecteds[F8E4M3FN][S2] = false; @@ -712,6 +782,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][C64] = true; expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; + expecteds[F8E4M3FN][F4E2M1FN] = false; expecteds[F8E4M3FN][F8E5M2] = false; expecteds[F8E4M3FN][F8E5M2FNUZ] = false; expecteds[F8E4M3FN][F8E4M3] = false; @@ -719,6 +790,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][F8E3M4] = false; + expecteds[F8E4M3FN][F8E8M0FNU] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S1] = false; expecteds[F8E4M3B11FNUZ][S2] = false; @@ -740,6 +812,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][C64] = true; expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; + expecteds[F8E4M3B11FNUZ][F4E2M1FN] = false; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; @@ -747,6 +820,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E3M4] = false; + expecteds[F8E4M3B11FNUZ][F8E8M0FNU] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S1] = false; expecteds[F8E5M2FNUZ][S2] = false; @@ -768,6 +842,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][C64] = true; expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; + expecteds[F8E5M2FNUZ][F4E2M1FN] = false; expecteds[F8E5M2FNUZ][F8E5M2] = false; expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; @@ -775,6 +850,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][F8E3M4] = false; + expecteds[F8E5M2FNUZ][F8E8M0FNU] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S1] = false; expecteds[F8E4M3FNUZ][S2] = false; @@ -796,6 +872,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][C64] = true; expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; + expecteds[F8E4M3FNUZ][F4E2M1FN] = false; expecteds[F8E4M3FNUZ][F8E5M2] = false; expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; @@ -803,6 +880,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E4M3FNUZ][F8E8M0FNU] = false; expecteds[F8E3M4][PRED] = false; expecteds[F8E3M4][S1] = false; expecteds[F8E3M4][S2] = false; @@ -824,6 +902,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][C64] = true; expecteds[F8E3M4][BF16] = true; expecteds[F8E3M4][C128] = true; + expecteds[F8E3M4][F4E2M1FN] = false; expecteds[F8E3M4][F8E5M2] = false; expecteds[F8E3M4][F8E5M2FNUZ] = false; expecteds[F8E3M4][F8E4M3] = false; @@ -831,6 +910,37 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][F8E4M3FNUZ] = false; expecteds[F8E3M4][F8E4M3B11FNUZ] = false; expecteds[F8E3M4][F8E3M4] = true; + expecteds[F8E3M4][F8E8M0FNU] = false; + expecteds[F8E8M0FNU][PRED] = false; + expecteds[F8E8M0FNU][S1] = false; + expecteds[F8E8M0FNU][S2] = false; + expecteds[F8E8M0FNU][S4] = false; + expecteds[F8E8M0FNU][S8] = false; + expecteds[F8E8M0FNU][S16] = false; + expecteds[F8E8M0FNU][S32] = false; + expecteds[F8E8M0FNU][S64] = false; + expecteds[F8E8M0FNU][U1] = false; + expecteds[F8E8M0FNU][U2] = false; + expecteds[F8E8M0FNU][U4] = false; + expecteds[F8E8M0FNU][U8] = false; + expecteds[F8E8M0FNU][U16] = false; + expecteds[F8E8M0FNU][U32] = false; + expecteds[F8E8M0FNU][U64] = false; + expecteds[F8E8M0FNU][F16] = false; + expecteds[F8E8M0FNU][F32] = true; + expecteds[F8E8M0FNU][F64] = true; + expecteds[F8E8M0FNU][C64] = true; + expecteds[F8E8M0FNU][BF16] = true; + expecteds[F8E8M0FNU][C128] = true; + expecteds[F8E8M0FNU][F4E2M1FN] = false; + expecteds[F8E8M0FNU][F8E5M2] = false; + expecteds[F8E8M0FNU][F8E4M3] = false; + expecteds[F8E8M0FNU][F8E4M3FN] = false; + expecteds[F8E8M0FNU][F8E4M3B11FNUZ] = false; + expecteds[F8E8M0FNU][F8E4M3FNUZ] = false; + expecteds[F8E8M0FNU][F8E5M2FNUZ] = false; + expecteds[F8E8M0FNU][F8E3M4] = false; + expecteds[F8E8M0FNU][F8E8M0FNU] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { @@ -851,7 +961,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { << primitive_util::LowercasePrimitiveTypeName(to_type); } } -} +} // NOLINT(readability/fn_size) } // namespace } // namespace xla diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index a79240f51a7e23..e1110543cb11ad 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -32,6 +32,7 @@ std::optional DType::byte_size() const { case kU2: case kS4: case kU4: + case kF4E2M1FN: // Smaller than a byte. return std::nullopt; case kPred: @@ -39,6 +40,7 @@ std::optional DType::byte_size() const { case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -77,12 +79,14 @@ std::optional DType::bit_size() const { return 2; case kS4: case kU4: + case kF4E2M1FN: return 4; case kPred: case kS8: case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -141,9 +145,11 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); + CASE(F4E2M1FN); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // CASE(F8E3M4); // CASE(F8E4M3); + CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -189,9 +195,11 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); + CASE(F4E2M1FN); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // CASE(F8E3M4); // CASE(F8E4M3); + CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index d23efc55a1aa12..864cdd1c063ae4 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -88,8 +88,12 @@ class DType { kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, + kF8E8M0FNU = 33, - // Next = 30 + // MX floating point types. + kF4E2M1FN = 32, + + // Next = 34 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index 3a2b0df7976d6e..2cf453f26c291d 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -70,12 +70,18 @@ message DTypeProto { KIND_F8E4M3FNUZ = 25; KIND_F8E5M2 = 19; KIND_F8E5M2FNUZ = 24; + KIND_F8E8M0FNU = 31; + + // MX floating point types. + KIND_F4E2M1FN = 30; // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind // needs to match xla.PrimitiveType enum, so choose a large enum to avoid // collision. KIND_STRING = 99; + + // Next: 32 } // LINT.ThenChange() Kind kind = 1; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 57fec6702d277d..9d3d3105f54e54 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -42,34 +42,21 @@ TEST(DTypeTest, FromToFromProto) { TEST(DTypeTest, ByteSize) { for (const auto& [kind, byte_size] : std::vector>({ - {DType::kS2, -1}, - {DType::kU2, -1}, - {DType::kS4, -1}, - {DType::kU4, -1}, - {DType::kPred, 1}, - {DType::kS8, 1}, - {DType::kU8, 1}, - {DType::kF8E3M4, 1}, - {DType::kF8E4M3, 1}, - {DType::kF8E4M3FN, 1}, - {DType::kF8E4M3B11FNUZ, 1}, - {DType::kF8E4M3FNUZ, 1}, - {DType::kF8E5M2, 1}, - {DType::kF8E5M2FNUZ, 1}, - {DType::kS16, 2}, - {DType::kU16, 2}, - {DType::kF16, 2}, - {DType::kBF16, 2}, - {DType::kS32, 4}, - {DType::kU32, 4}, - {DType::kF32, 4}, - {DType::kS64, 8}, - {DType::kU64, 8}, - {DType::kF64, 8}, - {DType::kC64, 8}, - {DType::kC128, 16}, - {DType::kToken, -1}, - {DType::kInvalid, -1}, + {DType::kS2, -1}, {DType::kU2, -1}, + {DType::kS4, -1}, {DType::kU4, -1}, + {DType::kPred, 1}, {DType::kS8, 1}, + {DType::kU8, 1}, {DType::kF4E2M1FN, -1}, + {DType::kF8E3M4, 1}, {DType::kF8E4M3, 1}, + {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, + {DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1}, + {DType::kF8E5M2FNUZ, 1}, {DType::kF8E8M0FNU, 1}, + {DType::kS16, 2}, {DType::kU16, 2}, + {DType::kF16, 2}, {DType::kBF16, 2}, + {DType::kS32, 4}, {DType::kU32, 4}, + {DType::kF32, 4}, {DType::kS64, 8}, + {DType::kU64, 8}, {DType::kF64, 8}, + {DType::kC64, 8}, {DType::kC128, 16}, + {DType::kToken, -1}, {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).byte_size(), @@ -80,34 +67,21 @@ TEST(DTypeTest, ByteSize) { TEST(DTypeTest, BitSize) { for (const auto& [kind, bit_size] : std::vector>({ - {DType::kS2, 2}, - {DType::kU2, 2}, - {DType::kS4, 4}, - {DType::kU4, 4}, - {DType::kPred, 8}, - {DType::kS8, 8}, - {DType::kU8, 8}, - {DType::kF8E3M4, 8}, - {DType::kF8E4M3, 8}, - {DType::kF8E4M3FN, 8}, - {DType::kF8E4M3B11FNUZ, 8}, - {DType::kF8E4M3FNUZ, 8}, - {DType::kF8E5M2, 8}, - {DType::kF8E5M2FNUZ, 8}, - {DType::kS16, 16}, - {DType::kU16, 16}, - {DType::kF16, 16}, - {DType::kBF16, 16}, - {DType::kS32, 32}, - {DType::kU32, 32}, - {DType::kF32, 32}, - {DType::kS64, 64}, - {DType::kU64, 64}, - {DType::kF64, 64}, - {DType::kC64, 64}, - {DType::kC128, 128}, - {DType::kToken, -1}, - {DType::kInvalid, -1}, + {DType::kS2, 2}, {DType::kU2, 2}, + {DType::kS4, 4}, {DType::kU4, 4}, + {DType::kPred, 8}, {DType::kS8, 8}, + {DType::kU8, 8}, {DType::kF4E2M1FN, 4}, + {DType::kF8E3M4, 8}, {DType::kF8E4M3, 8}, + {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, + {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8}, + {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 8}, + {DType::kS16, 16}, {DType::kU16, 16}, + {DType::kF16, 16}, {DType::kBF16, 16}, + {DType::kS32, 32}, {DType::kU32, 32}, + {DType::kF32, 32}, {DType::kS64, 64}, + {DType::kU64, 64}, {DType::kF64, 64}, + {DType::kC64, 64}, {DType::kC128, 128}, + {DType::kToken, -1}, {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).bit_size(), diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index 9c581ec6227cae..2af3281a588cce 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -44,6 +44,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF4E2M1FN, xla::PrimitiveType::F4E2M1FN); CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); @@ -51,6 +52,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); + CASE(DType::kF8E8M0FNU, xla::PrimitiveType::F8E8M0FNU); CASE(DType::kF16, xla::PrimitiveType::F16); CASE(DType::kF32, xla::PrimitiveType::F32); CASE(DType::kBF16, xla::PrimitiveType::BF16); @@ -83,6 +85,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: + case xla::PrimitiveType::F4E2M1FN: case xla::PrimitiveType::F8E3M4: case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: @@ -90,6 +93,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::F8E4M3FNUZ: case xla::PrimitiveType::F8E5M2: case xla::PrimitiveType::F8E5M2FNUZ: + case xla::PrimitiveType::F8E8M0FNU: case xla::PrimitiveType::F16: case xla::PrimitiveType::F32: case xla::PrimitiveType::BF16: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 631b0bcb9b9562..45baa4abf79351 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -184,6 +184,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E3M4; @@ -205,6 +208,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; } else if (std::is_same() || !options.squash_64bit_types) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); type = primitive_util::NativeToPrimitiveType(); @@ -398,6 +404,10 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } if (dtypes.np_float8_e3m4.has_value()) { (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; @@ -415,6 +425,10 @@ absl::StatusOr DevicePut(nb::handle arg, HandleNumpyScalar; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; @@ -595,8 +609,10 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 50366be350bc08..473c082e1425cc 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -58,6 +58,7 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + std::optional float4_e2m1fn; std::optional float8_e3m4; std::optional float8_e4m3; nb_dtype float8_e4m3fn; @@ -65,6 +66,7 @@ struct CustomDtypes { nb_dtype float8_e4m3fnuz; nb_dtype float8_e5m2; nb_dtype float8_e5m2fnuz; + std::optional float8_e8m0fnu; std::optional int2; nb_dtype int4; std::optional uint2; @@ -76,6 +78,10 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { + dtypes->float4_e2m1fn = + nb_dtype::from_args(ml_dtypes.attr("float4_e2m1fn")); + } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); } @@ -91,6 +97,10 @@ const CustomDtypes& GetCustomDtypes() { nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->float8_e5m2fnuz = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->float8_e8m0fnu = + nb_dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4")); dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4")); if (nb::hasattr(ml_dtypes, "int2")) { @@ -147,6 +157,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + if (custom_dtypes.float4_e2m1fn.has_value()) { + map->emplace(*custom_dtypes.float4_e2m1fn, F4E2M1FN); + } if (custom_dtypes.float8_e3m4.has_value()) { map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); } @@ -158,6 +171,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); map->emplace(custom_dtypes.float8_e5m2, F8E5M2); map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ); + if (custom_dtypes.float8_e8m0fnu.has_value()) { + map->emplace(*custom_dtypes.float8_e8m0fnu, F8E8M0FNU); + } if (custom_dtypes.int2.has_value()) { map->emplace(*custom_dtypes.int2, S2); } @@ -217,6 +233,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F4E2M1FN: + if (custom_dtypes.float4_e2m1fn.has_value()) { + return *custom_dtypes.float4_e2m1fn; + } + break; case F8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -237,6 +258,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return custom_dtypes.float8_e5m2; case F8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case F8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case BF16: return custom_dtypes.bfloat16; case F16: @@ -307,6 +333,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF4E2M1FN: + if (custom_dtypes.float4_e2m1fn.has_value()) { + return *custom_dtypes.float4_e2m1fn; + } + break; case ifrt::DType::kF8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -327,6 +358,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return custom_dtypes.float8_e5m2; case ifrt::DType::kF8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case ifrt::DType::kF8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case ifrt::DType::kString: // PEP 3118 code for "pointer to Python Object". We use Python objects // instead of 'U' (Unicode string) or 'V' (raw data) because the latter @@ -380,6 +416,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { + dtypes->np_float4_e2m1fn = nb::object(ml_dtypes.attr("float4_e2m1fn")); + } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); } @@ -392,6 +431,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_float8_e5m2 = nb::object(ml_dtypes.attr("float8_e5m2")); dtypes->np_float8_e4m3fnuz = nb::object(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->np_float8_e5m2fnuz = nb::object(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->np_float8_e8m0fnu = nb::object(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->np_float16 = nb::object(numpy.attr("float16")); dtypes->np_float32 = nb::object(numpy.attr("float32")); dtypes->np_float64 = nb::object(numpy.attr("float64")); diff --git a/xla/python/types.h b/xla/python/types.h index aacfea1a17997f..babdf5a9bd4167 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -81,6 +81,7 @@ struct NumpyScalarTypes { nanobind::object np_uint64; nanobind::object np_bfloat16; // Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0. + std::optional np_float4_e2m1fn; std::optional np_float8_e3m4; std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; @@ -88,6 +89,7 @@ struct NumpyScalarTypes { nanobind::object np_float8_e4m3fnuz; nanobind::object np_float8_e5m2; nanobind::object np_float8_e5m2fnuz; + std::optional np_float8_e8m0fnu; nanobind::object np_float16; nanobind::object np_float32; nanobind::object np_float64; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 51c96229493e4c..62f04cdb7ac78c 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -204,9 +204,11 @@ NB_MODULE(xla_extension, m) { .value("U32", U32) .value("U64", U64) .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // .value("F8E3M4", F8E3M4) // .value("F8E4M3", F8E4M3) + .value("F8E8M0FNU", F8E8M0FNU) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 040c781cd087d6..c58346f7f3ca92 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -280,8 +280,12 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): bfloat16 = ml_dtypes.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# Also, it would be better to conditionally import these based on whether they +# are in the current version of ml_dtypes. +# float4_e2m1fn = ml_dtypes.float4_e2m1fn # float8_e3m4 = ml_dtypes.float8_e3m4 # float8_e4m3 = ml_dtypes.float8_e4m3 +# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -301,8 +305,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), + # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index cac63a98c1b2de..c1bb4dbc3a6fc6 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -62,8 +62,10 @@ mlir_api_version: int bfloat16: type[numpy.generic] # TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn: type[numpy.generic] # float8_e3m4: type[numpy.generic] # float8_e4m3: type[numpy.generic] +# float8_e8m0fnu: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 35b4a1ee77964f..37718e3fa87900 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -55,8 +55,10 @@ bfloat16 = xla_client.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn = xla_client.float4_e2m1fn # float8_e3m4 = xla_client.float8_e3m4 # float8_e4m3 = xla_client.float8_e4m3 +# float8_e8m0fnu = xla_client.float8_e8m0fnu float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -189,7 +191,7 @@ def TestFactory(xla_backend, fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] standard_dtypes += fp8_dtypes # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float8_e3m4, float8_e4m3] + # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 2e3862285898f2..ee7df05462f7be 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -74,6 +74,7 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F4E2M1FN: PrimitiveType F8E3M4: PrimitiveType F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType @@ -81,6 +82,7 @@ class PrimitiveType(enum.IntEnum): F8E4M3FNUZ: PrimitiveType F8E5M2: PrimitiveType F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType BF16: PrimitiveType F16: PrimitiveType F32: PrimitiveType diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 0faa9f48263989..564cb0a5cf8a0f 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -601,6 +601,10 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&s4_support); FloatSupport u4_support(U4, U8); pipeline.AddPass(&u4_support); + FloatSupport f4e2m1fn_support(F4E2M1FN, F16); + pipeline.AddPass(&f4e2m1fn_support); + FloatSupport f8e8m0fnu_support(F8E8M0FNU, F32); + pipeline.AddPass(&f8e8m0fnu_support); // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index 18841d2712dcbc..90c4f6c82e4082 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ, F8E4M3, F8E3M4 + // F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN, F8E8M0FNU default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 740129585b14a0..58807e49e3a53e 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -809,6 +809,223 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, return f16_value; } +absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, + llvm::IRBuilderBase* b) { + auto i8_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt8Ty(), val); + }; + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // Cast the input value to an integer for bitwise manipulation. + // Get the absolute value of the input (discard the sign). + // f16_bits = bitcast(f16_value, int) + // f16_abs_bits = f16_bits & 0x7FFF + llvm::Value* f16_bits = b->CreateBitCast(f16_value, b->getInt16Ty()); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_bits, i16_const(0x7FFF)); + + // If the input absolute value is >= 7.0 or an infinity, the result saturates + // to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0. + // If (0 <= input <= 0.25), the result is rounded to 0.0. + // If the input is NaN, the result is undefined (implemented as minus zero). + // The rest of the cases are handled by the "happy path". + // is_overflow = f16_abs_bits >= 0x1.Cp2 + // is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows) + // is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows) + // is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold) + llvm::Value* is_overflow = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x4700)); // 7.0 + llvm::Value* is_one = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x3A00)); // 0.75 + llvm::Value* is_zero = + b->CreateICmpULE(f16_abs_bits, i16_const(0x3400)); // 0.25 + llvm::Value* is_nan = + b->CreateICmpUGT(f16_abs_bits, i16_const(0x7C00)); // inf + + // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as + // the type doesn't have Inf/NaN and can represent unbiased exponent 2). + // This case, as well as the denormal, is handled below. + TF_ASSIGN_OR_RETURN( + llvm::Value * reduced_precision, + EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1, + /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1, + /*quiet_nans=*/false, b)); + + // Cast the reduced precision value to an integer for bitwise manipulation. + // Discard the least significant (9) mantissa bits leaving 1 bit. + // Truncate to + // as_int16 = bitcast(reduced_precision, int) + // as_int8 = as_int16 >> (f16_mantissa - f4_mantissa) + llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty()); + llvm::Value* as_int8 = + b->CreateTrunc(b->CreateLShr(as_int16, mantissa_diff), b->getInt8Ty()); + + // Get the sign (0 or 1). + // f4_sign = as_int8 >> 6 + llvm::Value* f4_sign = b->CreateLShr(as_int8, 6); + + // Get exponent and mantissa bits without the sign. + // Important: the mask is 0x3F (not 0x7F), discard bit #6. + // f4_bits = as_int8 & 0x3F + llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F)); + + // Convert F16 exponent to F4 exponent by readjusting the exponent bias. + // This produces the "normal" result, i.e. not Inf or NaN or denormal. + // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) + constexpr int f4_exponent_offset = bias_diff << 1; + llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(f4_exponent_offset)); + + // If the rounding resulted in zero exponent, the value is incorrect. + // This happens when the input is < 1.0 + // is_underflow = f4_normal <= 1 + llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1)); + + // Chain of selects that handles the special cases. + // f4_result = + // is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) : + // is_overflow ? (is_nan ? -0.0 : 6.0) : + // f4_normal + llvm::Value* f4_result = b->CreateSelect( + is_underflow, + // If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0 + b->CreateSelect(is_one, i8_const(0x2), + b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1))), + // If overflow, the input is >= 7.0 or infinity or NaN. + b->CreateSelect(is_overflow, + b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)), + f4_normal)); + + // Add sign to the resulting value. + // f4_signed_result = (f4_sign << 3) | f4_result + return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3)); +} + +llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilderBase* b) { + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // The input value is a 8-bit integer, extend it to 16-bit integer. + // as_int16 = bitcast(f8_value, int) + llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); + + // Get the sign and shift it to F16 position. + // f4_sign = as_int16 >> 3 + // f16_sign_bit = f4_sign << 15 + llvm::Value* f4_sign = b->CreateLShr(as_int16, 3); + llvm::Value* f16_sign_bit = b->CreateShl(f4_sign, 15); + + // Get exponent and mantissa bits without the sign. + // f4_bits = as_int16 & 0x7 + // f16_bits = f4_bits << (f16_mantissa - f4_mantissa) + llvm::Value* f4_bits = b->CreateAnd(as_int16, i16_const(0x7)); + llvm::Value* f16_bits = b->CreateShl(f4_bits, mantissa_diff); + + // Convert F16 exponent to F4 exponent by readjusting the exponent bias. + // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) + constexpr int f16_exponent_offset = bias_diff << 10; + llvm::Value* f16_normal = + b->CreateAdd(f16_bits, i16_const(f16_exponent_offset)); + + // For denormal and zero, the exponent is different. Handle these cases + // separately below. + // is_denorm_or_zero = f4_bits <= 1 + // is_zero = f4_bits == 0 + llvm::Value* is_denorm_or_zero = b->CreateICmpULE(f4_bits, i16_const(1)); + llvm::Value* is_zero = b->CreateICmpEQ(f4_bits, i16_const(0)); + + // Chain of selects that handles the special cases. + // f16_result = is_denorm_or_zero ? (is_zero ? 0.0 : 0.5) : f16_normal + llvm::Value* f16_result = b->CreateSelect( + is_denorm_or_zero, + b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)), + f16_normal); + + // Add sign to the resulting value. + // f16_signed_result = f16_sign_bit | f16_result + llvm::Value* f16_signed_result = b->CreateOr(f16_result, f16_sign_bit); + return b->CreateBitCast(f16_signed_result, b->getHalfTy()); +} + +llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, + llvm::IRBuilderBase* b) { + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + + // Cast the input value to an integer for bitwise manipulation. + // as_int32 = bitcast(f32_value, int) + llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); + + // Check if the input is zero, negative, overflow, infinity or NaN. + // All of these cases cannot be represented in the E8M0 format. + // is_zero_or_negative = as_int32 <= 0 + // is_overflow_or_nan = as_int32 >= 0x1.8p127 + // is_nan = is_zero_or_negative | is_overflow_or_nan + llvm::Value* is_zero_or_negative = b->CreateICmpSLE(as_int32, i32_const(0)); + llvm::Value* is_overflow_or_nan = + b->CreateICmpSGE(as_int32, i32_const(0x7F400000)); // 1.5 * 2^127 + llvm::Value* is_nan = b->CreateOr(is_zero_or_negative, is_overflow_or_nan); + + // Check if the input is a denormal which should round to the minimum value + // (2^-127), as there is no zero value. + // is_denorm = as_int32 <= 0x1p-127 + llvm::Value* is_denorm = + b->CreateICmpULE(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + + // Round the value (always up) and discard the mantissa. + // rounded = as_int32 + 0x1p-127 + // f8_normal = as_int32 >> f32_mantissa + llvm::Value* rounded = + b->CreateAdd(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + llvm::Value* f8_normal = b->CreateAShr(rounded, 23); + + // Chain of selects that handles the special cases. + // f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal) + llvm::Value* f8_result = + b->CreateSelect(is_nan, i32_const(0xFF), + b->CreateSelect(is_denorm, i32_const(0x00), f8_normal)); + + // Truncate to the result type. + return b->CreateTrunc(f8_result, b->getInt8Ty()); +} + +llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilderBase* b) { + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + + // The input value is a 8-bit integer, extend it to 32-bit integer. + // as_int32 = bitcast(f8_value, int) + llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); + + // Check if the input is a denormal or NaN. + // is_zero = as_int32 == 0x00 + // is_nan = as_int32 == 0xFF + llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0)); + llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF)); + + // Shift exponent to the left for the normal case. + // f32_normal = as_int32 << mantissa_diff + llvm::Value* f32_normal = b->CreateShl(as_int32, 23); + + // Chain of selects that handles the special cases. + // f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal) + llvm::Value* f32_result = b->CreateSelect( + is_nan, i32_const(0x7FC00000), + b->CreateSelect(is_zero, i32_const(0x400000), f32_normal)); + + // Bitcast integer bits to the result type. + return b->CreateBitCast(f32_result, b->getFloatTy()); +} + llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, @@ -903,6 +1120,18 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F4E2M1FN) { + return EmitF16ToF4e2m1fn( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } + if (to_type == F8E8M0FNU) { + return EmitF32ToF8e8m0fnu( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + b_), + b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz( F16, @@ -1108,10 +1337,29 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F4E2M1FN) { + TF_RET_CHECK(to_type != F4E2M1FN); + operand_value = EmitF4e2m1fnToF16(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F8E8M0FNU) { + TF_RET_CHECK(to_type != F8E8M0FNU); + operand_value = EmitF8e8m0fnuToF32(operand_value, b_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { TF_RET_CHECK(to_type != from_type); PrimitiveType cast_type = primitive_util::IsFloatingPointType(to_type) ? to_type : F16; + if (to_type == F8E8M0FNU || to_type == F4E2M1FN) { + cast_type = F32; + } TF_ASSIGN_OR_RETURN(operand_value, EmitF8fnuzToFloating(from_type, operand_value, cast_type, b_, module_)); @@ -1184,6 +1432,24 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } + if (to_type == F4E2M1FN) { + // Cast to F16 first. Casts to F4E2M1FN must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); + } + return EmitF16ToF4e2m1fn(operand_value, b_); + } + if (to_type == F8E8M0FNU) { + // Cast to F32 first. Casts to F8E8M0FNU must be from F32. + if (from_type != F32) { + operand_value = b_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(F32, module_->getContext())); + } + return EmitF32ToF8e8m0fnu(operand_value, b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } @@ -1734,6 +2000,12 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); + } else if (operand_type == F4E2M1FN) { + lhs_value = EmitF4e2m1fnToF16(lhs_value, b_); + rhs_value = EmitF4e2m1fnToF16(rhs_value, b_); + } else if (operand_type == F8E8M0FNU) { + lhs_value = EmitF8e8m0fnuToF32(lhs_value, b_); + rhs_value = EmitF8e8m0fnuToF32(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { TF_ASSIGN_OR_RETURN( lhs_value, @@ -3588,10 +3860,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; - if (component_element_type == F8E4M3FNUZ) { - float_ir_type = - llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); - } else if (component_element_type == F8E5M2FNUZ) { + if (component_element_type == F8E4M3FNUZ || + component_element_type == F8E5M2FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); } else { diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 71847a88ca518a..b3f4b8ddef8949 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -99,9 +99,10 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -613,7 +614,9 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -628,6 +631,10 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { auto tname = this->TypeName(); + if (std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } const auto hlo_text = absl::StrReplaceAll(R"( HloModule matmul diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index 4afb96362cf86e..e0be95da5f6680 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -40,6 +40,8 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F4E2M1FN: + return &llvm::APFloat::Float4E2M1FN(); case F8E3M4: return &llvm::APFloat::Float8E3M4(); case F8E4M3: @@ -54,6 +56,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( return &llvm::APFloat::Float8E5M2(); case F8E5M2FNUZ: return &llvm::APFloat::Float8E5M2FNUZ(); + case F8E8M0FNU: + return &llvm::APFloat::Float8E8M0FNU(); case BF16: return &llvm::APFloat::BFloat(); case F16: @@ -72,6 +76,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, PrimitiveType type) { switch (type) { + case F4E2M1FN: + return b->getIntNTy(4); case F8E3M4: case F8E4M3: case F8E4M3B11FNUZ: @@ -79,6 +85,7 @@ absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, case F8E4M3FNUZ: case F8E5M2: case F8E5M2FNUZ: + case F8E8M0FNU: return b->getInt8Ty(); case BF16: return b->getBFloatTy(); @@ -649,8 +656,14 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign); // Bitwise or the sign bit back in. - sign = b->CreateZExt(sign, output_int_type); - sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type)); + int shift = output_type_bit_width - BitWidth(input_type); + if (shift >= 0) { + sign = b->CreateZExt(sign, output_int_type); + sign = b->CreateShl(sign, shift); + } else { + sign = b->CreateLShr(sign, -shift); + sign = b->CreateTrunc(sign, output_int_type); + } llvm::Value* result = b->CreateOr(sign, result_abs); // Bitcast to the output type. diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc index 5d0c696ccc9807..897bc03d783151 100644 --- a/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -550,9 +550,18 @@ INSTANTIATE_TEST_SUITE_P( using ReduceTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; +static absl::string_view init_value(PrimitiveType dtype) { + if (dtype == C64 || dtype == C128) { + return "(0, 0)"; + } else if (dtype == F8E8M0FNU) { + return "1e-40"; + } else { + return "0"; + } +} + TEST_P(ReduceTest, IsTritonSupportedReduction) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -567,7 +576,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -599,7 +608,6 @@ TEST_P( ReduceTest, UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -614,7 +622,7 @@ ENTRY triton_computation { ROOT reduce = $0[2] reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -624,7 +632,6 @@ ENTRY triton_computation { TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -638,7 +645,7 @@ ENTRY triton_computation { constant_0 = $0[] constant($1) ROOT reduce = $0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -649,7 +656,6 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( add { @@ -670,7 +676,7 @@ ENTRY triton_computation { dimensions={1}, to_apply=add ROOT reduce = $0[125] get-tuple-element(tuple), index=0 })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -701,7 +707,6 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( custom_call { @@ -716,7 +721,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -740,7 +745,6 @@ using ReductionComputationTest = // computation and in regular HLO. See triton_support.cc for more details. TEST_P(ReductionComputationTest, DifferentBinaryOps) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute( R"( reduce_computation { @@ -755,7 +759,7 @@ ENTRY triton_computation { ROOT reduce = $0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=reduce_computation })", - "$0", HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0"); + "$0", HloOpcodeString(opcode), init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -1115,13 +1119,12 @@ TEST_P(ConstantTest, ConstantEffectiveScalar) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $0[1,1] constant({{$1}}) })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, @@ -1133,13 +1136,12 @@ TEST_P(ConstantTest, Constant2D) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $0[3,3] constant({{$1,$1,$1},{$1,$1,$1},{$1,$1,$1}}) })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + "$0", init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index faeaa7a6c46679..666c187998cb63 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1478,6 +1478,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); const GpuFloatSupport s4_support(gpu_version, S4, S8); const GpuFloatSupport u4_support(gpu_version, U4, U8); + const GpuFloatSupport f4e2m1fn_support(gpu_version, F4E2M1FN, F16); + const GpuFloatSupport f8e8m0fnu_support(gpu_version, F8E8M0FNU, F32); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); @@ -1491,6 +1493,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&f8e3m4_support); sub_pipeline.AddPass(&s4_support); sub_pipeline.AddPass(&u4_support); + sub_pipeline.AddPass(&f4e2m1fn_support); + sub_pipeline.AddPass(&f8e8m0fnu_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 16383324dfb016..6e0e14e320a7f9 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,9 +29,10 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f8e5m2", "f8e5m2fnuz", "f8e4m3", - "f8e4m3fn", "f8e4m3fnuz", - "f8e4m3b11fnuz", "f8e3m4")); + "f4e2m1fn", "f8e5m2", "f8e5m2fnuz", + "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4", + "f8e8m0fnu")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 88823f1dd9e5c1..38dfd05667e009 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2972,9 +2972,10 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); if (instruction->opcode() == HloOpcode::kConvert || instruction->opcode() == HloOpcode::kCompare || + instruction->opcode() == HloOpcode::kIsFinite || (instruction->opcode() == HloOpcode::kSelect && operand_shape.element_type() == PRED)) { - // Convert and Compare instructions can change element_size_in_bits + // Some instructions can change element_size_in_bits // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index d56172dd4b254a..b937dbc1500b69 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -199,6 +199,8 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case S16: case U16: return llvm::Type::getInt16Ty(context); + case F4E2M1FN: + return llvm::Type::getIntNTy(context, 4); case F8E5M2: case F8E5M2FNUZ: case F8E4M3: @@ -206,6 +208,7 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case F8E4M3B11FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(context); case BF16: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index f5246389e485c3..e3e7d1f17e312f 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,6 +37,10 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. template <> +struct ToDataType { + static constexpr DataType value = DataType::kF4E2M1FN; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E3M4; }; @@ -61,6 +65,10 @@ struct ToDataType { static constexpr DataType value = DataType::kF8E5M2FNUZ; }; template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E8M0FNU; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kFloat; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 6b7a87d80b3aec..24851e56d75eda 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -69,12 +69,14 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6aee86bf2cbc19..182af599af9e5c 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -56,6 +56,10 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { return DataType::kF8E4M3FNUZ; case PrimitiveType::F8E3M4: return DataType::kF8E3M4; + case PrimitiveType::F4E2M1FN: + return DataType::kF4E2M1FN; + case PrimitiveType::F8E8M0FNU: + return DataType::kF8E8M0FNU; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -93,6 +97,10 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { return PrimitiveType::F8E4M3FNUZ; case DataType::kF8E3M4: return PrimitiveType::F8E3M4; + case DataType::kF4E2M1FN: + return PrimitiveType::F4E2M1FN; + case DataType::kF8E8M0FNU: + return PrimitiveType::F8E8M0FNU; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -154,6 +162,8 @@ absl::StatusOr GetBlasComputationType( case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through case PrimitiveType::F8E3M4: // fall-through + case PrimitiveType::F4E2M1FN: // fall-through + case PrimitiveType::F8E8M0FNU: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index e5730121addd8d..8864476bf0d825 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -39,8 +39,10 @@ hipDataType AsHipblasDataType(blas::DataType type) { case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: case blas::DataType::kF8E3M4: - LOG(FATAL) - << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; + case blas::DataType::kF4E2M1FN: + case blas::DataType::kF8E8M0FNU: + LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, " + "F8E3M4, F4E2M1FN and F8E8M0FNU"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 63d64c4be70c77..f82288cb8b1046 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -863,12 +863,14 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:types", + "//xla:util", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@ml_dtypes//:float8", "@tsl//tsl/platform:ml_dtypes", ] + if_rocm_is_configured([ diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index c12ce79a06e8fa..f2fb97be51f68d 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include #include "absl/base/casts.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -47,6 +48,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" +#include "xla/util.h" #include "tsl/platform/ml_dtypes.h" #if TENSORFLOW_USE_ROCM @@ -93,6 +95,20 @@ std::pair, std::vector> AllSignedPairs( return {xs, ys}; } +template +void AddNegativeValuesMaybeRemoveZero(std::vector& values) { + values.reserve(values.size() * 2); + if (!has_zero_v) { + values.erase(values.begin()); + } + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto neg = -values[i]; + if (SignAndMagnitude(neg).first) { + values.push_back(neg); + } + } +} + class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: static constexpr float kEpsF32 = std::numeric_limits::epsilon(); @@ -1371,14 +1387,7 @@ class TotalOrderTest : public ClientLibraryTestBase { values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); } #endif - values.reserve(values.size() * 2); - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); std::vector lhs_data; std::vector rhs_data; lhs_data.reserve(values.size() * values.size()); @@ -1423,19 +1432,24 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types; +#if !defined(XLA_TEST_BACKEND_TPU) + // TODO(b/385004399): Run tests on these types on TPU. + tsl::float4_e2m1fn, tsl::float8_e8m0fnu, +#endif + float>; TYPED_TEST_SUITE(TotalOrderTest, Types); @@ -1462,13 +1476,7 @@ TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { if constexpr (std::numeric_limits::has_infinity) { values.push_back(std::numeric_limits::infinity()); } - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); auto lhs = ConstantR1(&builder, values); auto rhs = ConstantR1( &builder, diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 9650077ed57b28..9e191a30b405ae 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -52,7 +52,13 @@ using FloatTypes = ::testing::Types; + tsl::float8_e5m2fnuz +#ifndef XLA_TEST_BACKEND_TPU + // TODO(b/385004399): Run tests on these types on TPU. + , + tsl::float4_e2m1fn, tsl::float8_e8m0fnu +#endif + >; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 4f06ea0cc290c7..a8e370ad50c0d3 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,9 +54,17 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -741,10 +749,11 @@ XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { XlaBuilder builder(this->TestName()); using FP = TypeParam; - auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.5}, FP{2.0}, FP{-0.0}}); ConvertElementType(a, PRED); - std::array expected = {false, true, true, false}; + bool zero_pred = !has_zero_v; + std::array expected = {zero_pred, true, true, zero_pred}; this->template ComputeAndCompareR1(&builder, expected, {}); } @@ -1925,5 +1934,283 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F4E2M1FN + +XLA_TEST_F(ConvertTest, DISABLED_ON_TPU(ConvertF16F4e2m1fnRoundtrip)) { + // Convert from FP16 to FP4, then back to FP16. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFCp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.8p-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.004p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FCp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f4 = + ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, + DISABLED_ON_TPU(DISABLED_ON_CPU(ConvertF32F4e2m1fnRoundtrip))) { + // Convert from FP32 to FP4, then back to FP32. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFFFFEp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.8p-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.000002p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FFFFEp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f4 = ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive)) { + // Convert from FP4 to supported floating point type, then back to FP4. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f4_as_fp = + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f4_as_fp, F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive2)) { + // Convert from supported floating point type to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive3)) { + // Convert from FP4 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, + DISABLED_ON_TPU(ConvertF4e2m1fnF16RoundtripExhaustive4)) { + // Convert from (B)F16 to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +// ----- F8E8M0FNU + +XLA_TEST_F(ConvertTest, DISABLED_ON_TPU(ConvertF32F8e8m0fnuRoundtrip)) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, nan}, // No zero values + {-0.0, nan}, + {1.0, 1.0}, + {-1.0, nan}, // No negative values + {nan, nan}, + {inf, nan}, + // clang-format on + {0x1.8p1, 0x1p2}, // Round-to-even up + {0x1.8p2, 0x1p3}, // Round-to-even up (always rounds up) + {0x1p127, 0x1p127}, // Max value + {0x1.7FFFFEp127, 0x1p127}, // Largest number that doesn't overflow + {0x1.8p127, nan}, // Smallest number that overflows + {0x1.FFFFFEp127, nan}, // Overflow + {0x1p-126, 0x1p-126}, // Smallest F8 normal + {0x0.800002p-126, 0x1p-126}, // Smallest number rounding up to normal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E8M0FNU); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive)) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive2)) { + if (this->client_->platform()->Name() == "Host") { + // This test is disabled on CPU, as converting 0x1p-127 from double to float + // using CVTSD2SS on x64 results in an underflow (even though the result is + // representable as denormalized float32). + if (std::is_same_v) { + GTEST_SKIP() << "Skipping test for double precision floating point that " + "loses denormal value during conversion"; + } + } + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive3)) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, + DISABLED_ON_TPU(ConvertF8e8m0fnuF16RoundtripExhaustive4)) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 7f0d9c4507a2a2..d1d6882b6532a5 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -121,6 +121,7 @@ enum PrimitiveType { F64, C64, C128, + F4E2M1FN, F8E5M2, F8E4M3, F8E4M3FN, @@ -128,17 +129,19 @@ enum PrimitiveType { F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, + F8E8M0FNU, }; const std::vector& primitive_strings() { static auto vec = new std::vector( - {"s1", "s2", "s4", "s8", - "s16", "s32", "s64", "u1", - "u2", "u4", "u8", "u16", - "u32", "u64", "f16", "bf16", - "f32", "f64", "c64", "c128", - "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz", "f8e3m4"}); + {"s1", "s2", "s4", "s8", + "s16", "s32", "s64", "u1", + "u2", "u4", "u8", "u16", + "u32", "u64", "f16", "bf16", + "f32", "f64", "c64", "c128", + "f4e2m1fn", "f8e3m4", "f8e4m3", "f8e4m3b11fnuz", + "f8e4m3fn", "f8e4m3fnuz", "f8e5m2", "f8e5m2fnuz", + "f8e8m0fnu"}); return *vec; } @@ -415,6 +418,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F64: return FillFloatT(buffer, num_elements); + case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -422,6 +426,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: @@ -475,6 +480,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F64: return DisplayT(buffer, num_elements); + case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -482,6 +488,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: diff --git a/xla/tsl/framework/BUILD b/xla/tsl/framework/BUILD index 9f9a030d868e79..9185021d405501 100644 --- a/xla/tsl/framework/BUILD +++ b/xla/tsl/framework/BUILD @@ -339,6 +339,7 @@ cc_library( ]), deps = [ ":numeric_types", + "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:types", ], ) diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index f7a9bd7a54bc91..2292ee563db80c 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "xla/tsl/framework/numeric_types.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/types.h" namespace tsl { @@ -70,13 +71,15 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 2ac31005c16629..4a6d8fff6f72cd 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -24,6 +24,8 @@ enum DataType { kInt64 = 12; kF8E4M3 = 13; kF8E3M4 = 14; + kF4E2M1FN = 15; + kF8E8M0FNU = 16; } // Describes how a convolution input or output layer's data is formatted. diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index e2c5eb295c6b12..a986efb7cca963 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,8 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float4_e2m1fn = + py::dtype::from_args(ml_dtypes.attr("float4_e2m1fn")).num(); numpy_dtypes.float8_e3m4 = py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); numpy_dtypes.float8_e4m3 = @@ -75,6 +77,8 @@ struct MlDtypesInitInfo { py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")).num(); numpy_dtypes.float8_e5m2fnuz = py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")).num(); + numpy_dtypes.float8_e8m0fnu = + py::dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")).num(); numpy_dtypes.int4 = py::dtype::from_args(ml_dtypes.attr("int4")).num(); numpy_dtypes.uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")).num(); } catch (const std::exception& e) { @@ -85,6 +89,7 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float4_e2m1fn == NPY_NOTYPE || numpy_dtypes.float8_e3m4 == NPY_NOTYPE || numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || @@ -92,6 +97,7 @@ struct MlDtypesInitInfo { numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || numpy_dtypes.float8_e5m2 == NPY_NOTYPE || numpy_dtypes.float8_e5m2fnuz == NPY_NOTYPE || + numpy_dtypes.float8_e8m0fnu == NPY_NOTYPE || numpy_dtypes.int4 == NPY_NOTYPE || numpy_dtypes.uint4 == NPY_NOTYPE) { init_valid = false; } diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index b3aa94e430239a..725d844c27bb4e 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,7 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float4_e2m1fn; int float8_e3m4; int float8_e4m3; int float8_e4m3fn; @@ -31,6 +32,7 @@ struct NumpyDtypes { int float8_e4m3fnuz; int float8_e5m2; int float8_e5m2fnuz; + int float8_e8m0fnu; int int4; int uint4; }; diff --git a/xla/types.h b/xla/types.h index 98e3d7c9331ffc..b702404601dae7 100644 --- a/xla/types.h +++ b/xla/types.h @@ -131,16 +131,32 @@ struct make_specialized_signed>> { template using make_specialized_signed_t = typename make_specialized_signed::type; +// has_negative_zero[_v] + template struct has_negative_zero : std::bool_constant::is_iec559> {}; +template <> +struct has_negative_zero : std::bool_constant {}; + template <> struct has_negative_zero : std::bool_constant {}; template inline constexpr bool has_negative_zero_v = has_negative_zero::value; +// has_zero[_v] + +template +struct has_zero : std::bool_constant {}; + +template <> +struct has_zero : std::bool_constant {}; + +template +inline constexpr bool has_zero_v = has_zero::value; + } // namespace xla #endif // XLA_TYPES_H_ diff --git a/xla/util.cc b/xla/util.cc index c18435a04c64bf..023e09342f113b 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -148,6 +148,7 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { + static_assert(std::numeric_limits::has_quiet_NaN); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, @@ -174,6 +175,10 @@ static std::string GenericRoundTripFpToString(FloatT value) { static_cast(value)); } +std::string RoundTripFpToString(tsl::float4_e2m1fn value) { + return GenericRoundTripFpToString(value); +} + std::string RoundTripFpToString(tsl::float8_e5m2 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); @@ -212,6 +217,11 @@ std::string RoundTripFpToString(tsl::float8_e3m4 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e8m0fnu value) { + std::string result = GenericRoundTripFpToString(value); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index 959009073e96f9..a4578709392445 100644 --- a/xla/util.h +++ b/xla/util.h @@ -416,6 +416,9 @@ std::string VectorString(const std::initializer_list& c) { return VectorString>(c); } +// Returns a string which can losslessly round trip to a float4 E2M1FN. +std::string RoundTripFpToString(tsl::float4_e2m1fn value); + // Returns a string which can losslessly round trip to a float8 E5M2. std::string RoundTripFpToString(tsl::float8_e5m2 value); @@ -437,6 +440,9 @@ std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); // Returns a string which can losslessly round trip to a float8 E3M4. std::string RoundTripFpToString(tsl::float8_e3m4 value); +// Returns a string which can losslessly round trip to a float8 E8M0FNU. +std::string RoundTripFpToString(tsl::float8_e8m0fnu value); + // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); @@ -652,8 +658,9 @@ template auto SignAndMagnitude(T x) { using BitType = UnsignedIntegerTypeForSizeType; BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); - const BitType x_bits = Eigen::numext::bit_cast(x); - const BitType x_sign = x_bits ^ x_abs_bits; + // Eigen implements the sign value to be either all-zeros (for positive input) + // or all-ones (for negative input). + BitType x_sign = Eigen::numext::bit_cast(Eigen::numext::signbit(x)); if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. @@ -664,12 +671,17 @@ auto SignAndMagnitude(T x) { return std::make_pair(x_sign, x_abs_bits); } +template <> +inline auto SignAndMagnitude(tsl::float8_e8m0fnu x) { + uint8_t x_bits = Eigen::numext::bit_cast(x); + return std::make_pair(static_cast(0), x_bits); +} + template auto SignAndMagnitudeToTwosComplement(T sign, T magnitude) { static_assert(!std::numeric_limits::is_signed); using SignedType = std::make_signed_t; - return static_cast(magnitude) ^ - (static_cast(sign) < 0 ? SignedType{-1} : SignedType{0}); + return static_cast(magnitude) ^ static_cast(sign); } // Returns the signed magnitude of T. @@ -679,6 +691,11 @@ auto ToSignMagnitude(T input) { return SignAndMagnitudeToTwosComplement(sign, magnitude); } +template <> +inline auto ToSignMagnitude(tsl::float8_e8m0fnu input) { + return Eigen::numext::bit_cast(input); +} + template constexpr int NanPayloadBits() { // Floating point types with signaling NaNs have payloads. diff --git a/xla/util_test.cc b/xla/util_test.cc index cc2465099c1d98..f864b3215aa4af 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -206,9 +206,9 @@ namespace { template void TotalOrderHelper(T x, T y) { auto x_sm = ToSignMagnitude(x); - bool x_sign = static_cast(Eigen::numext::signbit(x)); - bool y_sign = static_cast(Eigen::numext::signbit(y)); auto y_sm = ToSignMagnitude(y); + bool x_sign = static_cast(SignAndMagnitude(x).first); + bool y_sign = static_cast(SignAndMagnitude(y).first); if (x_sign && !y_sign) { EXPECT_LT(x_sm, y_sm) << x << " " << y; } @@ -239,6 +239,18 @@ void TotalOrderHelper(T x, T y) { } } // namespace +TEST(UtilTest, TotalOrder_F4E2M1FN) { + for (int a = 0; a < 16; ++a) { + tsl::float4_e2m1fn x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 16; ++b) { + tsl::float4_e2m1fn y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E5M2) { for (int a = 0; a < 256; ++a) { tsl::float8_e5m2 x = @@ -325,6 +337,18 @@ TEST(UtilTest, TotalOrder_F8E3M4) { } } +TEST(UtilTest, TotalOrder_F8E8M0FNU) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e8m0fnu x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e8m0fnu y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 82b822f2e3ecb9..87a4b3b35c049c 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -111,6 +111,17 @@ enum PrimitiveType { F8E5M2FNUZ = 24; F8E4M3FNUZ = 25; + // MX float dtypes, as described in: + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + // + // F4E2M1FN has 2 exponent bits and 1 mantissa bit. + // F8E8M0FNU has 8 exponent bits, no mantissa and no sign. + // + // Only finite values are supported (hence "FN" suffix). Unlike IEEE types, + // infinities and NaNs are not supported. + F4E2M1FN = 32; + F8E8M0FNU = 33; + // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. C128 = 18; // Paired F64 (real, imag), as in std::complex. @@ -136,7 +147,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 32 + // Next = 34 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -581,15 +592,17 @@ message LiteralProto { bytes bf16s = 13; bytes u16s = 16; bytes s16s = 17; - bytes f8e5m2s = 19; - bytes f8e4m3s = 28; - bytes f8e4m3fns = 20; + bytes f4e2m1fns = 30; + bytes f8e3m4s = 29; bytes f8e4m3b11fnuzs = 23; - bytes f8e5m2fnuzs = 24; + bytes f8e4m3fns = 20; bytes f8e4m3fnuzs = 25; - bytes f8e3m4s = 29; + bytes f8e4m3s = 28; + bytes f8e5m2fnuzs = 24; + bytes f8e5m2s = 19; + bytes f8e8m0fnus = 31; repeated int64 sparse_indices = 14; - // Next = 30 + // Next = 32 } message WindowDimension {