From d7e00c49a4b4f26c06266d6bb941275e67464c01 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 13 Jan 2025 20:12:29 +0000 Subject: [PATCH] Add F4E2M1FN and F8E8M0FNU types --- xla/array2d_test.cc | 28 ++ .../codegen/transforms/expand_float_ops.cc | 191 ++++++----- .../gpu/codegen/transforms/lower_tensors.cc | 59 ++-- .../transforms/tests/expand_float_ops.mlir | 50 +++ .../transforms/tests/lower_tensors.mlir | 42 ++- xla/comparison_util.h | 9 +- xla/ffi/api/api.h | 4 + xla/ffi/api/c_api.h | 2 + xla/ffi/api/ffi.h | 6 + xla/ffi/api/ffi_test.cc | 6 + xla/ffi/call_frame.cc | 2 + xla/fp_util_test.cc | 70 +++++ xla/hlo/builder/lib/math.cc | 11 +- xla/hlo/builder/lib/math_test.cc | 32 +- xla/hlo/evaluator/BUILD | 1 + xla/hlo/evaluator/hlo_evaluator.cc | 2 +- .../evaluator/hlo_evaluator_typed_visitor.h | 2 + .../hlo_evaluator_typed_visitor_mxfloat.cc | 23 ++ .../expanders/comparison_expander.cc | 59 ++-- .../simplifiers/float_normalization.cc | 3 + .../simplifiers/float_normalization_test.cc | 4 +- xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc | 20 ++ .../translate/hlo_to_mhlo/tests/import.hlo | 20 +- .../translate/mhlo_to_hlo/literal_exporter.cc | 6 + .../translate/mhlo_to_hlo/tests/export.mlir | 18 +- xla/literal.cc | 28 +- xla/literal.h | 29 +- xla/literal_comparison.cc | 7 +- xla/literal_comparison_test.cc | 52 +-- xla/literal_test.cc | 75 +++-- xla/mlir/utils/type_util.cc | 10 +- xla/mlir/utils/type_util_test.cc | 2 + xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 14 + xla/pjrt/c/CHANGELOG.md | 2 + xla/pjrt/c/pjrt_c_api.h | 6 +- xla/pjrt/c/pjrt_c_api_helpers.cc | 8 + xla/primitive_util.cc | 12 + xla/primitive_util.h | 80 ++++- xla/primitive_util_test.cc | 134 +++++++- xla/python/ifrt/dtype.cc | 8 + xla/python/ifrt/dtype.h | 6 +- xla/python/ifrt/dtype.proto | 6 + xla/python/ifrt/dtype_test.cc | 86 ++--- xla/python/pjrt_ifrt/pjrt_dtype.cc | 4 + xla/python/py_values.cc | 16 + xla/python/types.cc | 42 +++ xla/python/types.h | 2 + xla/python/xla.cc | 2 + xla/python/xla_client.py | 6 + xla/python/xla_client.pyi | 2 + xla/python/xla_client_test.py | 4 +- xla/python/xla_extension/__init__.pyi | 2 + xla/service/cpu/cpu_compiler.cc | 4 + xla/service/cpu/onednn_memory_util.h | 2 +- xla/service/elemental_ir_emitter.cc | 278 +++++++++++++++- xla/service/elemental_ir_emitter_test.cc | 15 +- xla/service/float8_fnuz_ir_emitter.cc | 17 +- .../gpu/fusions/triton/triton_support_test.cc | 55 ++-- xla/service/gpu/gpu_compiler.cc | 4 + .../gpu/tests/float_conversions_test.cc | 7 +- xla/service/hlo_verifier.cc | 3 +- xla/service/llvm_ir/llvm_util.cc | 3 + xla/stream_executor/data_type.h | 8 + xla/stream_executor/dnn.cc | 2 + xla/stream_executor/gpu/gpu_blas_lt.cc | 10 + xla/stream_executor/rocm/hip_blas_utils.cc | 6 +- xla/tests/array_elementwise_ops_test.cc | 50 +-- xla/tests/constants_test.cc | 8 +- xla/tests/convert_test.cc | 297 +++++++++++++++++- xla/tools/driver.cc | 21 +- xla/tsl/framework/type_traits.h | 4 +- xla/tsl/protobuf/dnn.proto | 2 + xla/tsl/python/lib/core/ml_dtypes.cc | 6 + xla/tsl/python/lib/core/ml_dtypes.h | 2 + xla/types.h | 16 + xla/util.cc | 10 + xla/util.h | 25 +- xla/util_test.cc | 28 +- xla/xla_data.proto | 27 +- 79 files changed, 1848 insertions(+), 377 deletions(-) create mode 100644 xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 055a6e77420819..b28e98e990f20c 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF4E2M1FN) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up +} + +TEST(Array2dTest, LinspaceF8E8M0FNU) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 2.0); // 1.5 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 4.0); // 3.0 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc index 81cb99d66f82d9..ff2ce862277980 100644 --- a/xla/backends/gpu/codegen/transforms/expand_float_ops.cc +++ b/xla/backends/gpu/codegen/transforms/expand_float_ops.cc @@ -163,7 +163,13 @@ int GetSignificandBits(mlir::FloatType ty) { } int GetExponentBias(mlir::FloatType ty) { - return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()); + return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) - + ty.isFloat8E8M0FNU(); // No zero exponent for E8M0. +} + +bool IsFNUZ(mlir::FloatType ty) { + return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E5M2FNUZ(); } Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -175,7 +181,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { return b.create(ma::CmpFPredicate::OEQ, value, inf); } - assert(ty.getIntOrFloatBitWidth() == 8); + assert(ty.getIntOrFloatBitWidth() <= 8); // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. if (ty.isFloat8E5M2()) { Val bits{b.create(b.getI8Type(), value), &b}; @@ -196,6 +202,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { if (mlir::LLVM::isCompatibleOuterType(ty)) { return b.create(ma::CmpFPredicate::UNO, value, value); } + if (ty.isFloat4E2M1FN()) { + return b.create(false, b.getI1Type()); + } assert(ty.getIntOrFloatBitWidth() == 8); Val bits{b.create(b.getI8Type(), value), &b}; @@ -207,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { return (bits & 0b0111'1111) == 0b0111'1111; } else if (ty.isFloat8E3M4()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); + } else if (ty.isFloat8E8M0FNU()) { + return bits == 0xFF; } return bits == 0x80; } @@ -281,11 +292,18 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth()); mlir::IntegerType wide_int_ty; - if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) { + if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) { wide_int_ty = b.getI16Type(); } else { wide_int_ty = b.getIntegerType( std::max(from_int_ty.getWidth(), to_int_ty.getWidth())); + // Avoid overflow for bit shifts. + auto may_overflow = [&](mlir::Type a, mlir::Type b) { + return a.isFloat8E8M0FNU() && b.isF16(); + }; + if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) { + wide_int_ty = b.getI32Type(); + } } auto convert_int = [&](mlir::Type ty, Value v) -> Val { if (v.getType() == ty) { @@ -300,34 +318,49 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, int64_t exp_offset = to_bias - from_bias; int digit_shift = to_mantissa - from_mantissa; - Val from_bits{ - b.create( - b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value), - &b}; + int from_width = value.getType().getIntOrFloatBitWidth(); + Val from_bits{b.create(b.getIntegerType(from_width), value), + &b}; + if (from_width < 8) { + from_bits = convert_int(b.getIntegerType(8), from_bits); + } auto cst = [&](mlir::Type ty, int64_t n) -> Val { return {b.create(n, ty), &b}; }; // Shift bits to destination type, without sign bit. - Val from_sign_bit = - from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0; - - from_bits = - from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); - - Value result_is_inf = IsInf(value, b); - Value input_is_nan = IsNaN(value, b); + Val from_sign_bit; + if (!from_ty.isFloat8E8M0FNU()) { + from_sign_bit = from_bits.shrui(from_width - 1) != 0; + from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); + } auto cst_bits = [&](llvm::APFloat f) { return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())), f.bitcastToAPInt().getZExtValue()); }; - Value to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); - Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); - Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + Value to_nan; + Value to_inf; + Val to_zero; + + // MX float types have neither infinities nor NaNs. + if (to_ty.isFloat4E2M1FN()) { + to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + to_nan = to_zero | 0x8; + to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); + } else if (to_ty.isFloat8E8M0FNU()) { + to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); + to_inf = to_nan; + to_zero = Val{to_nan, &b}; + } else { + to_inf = cst_bits(llvm::APFloat::getInf(to_ty.getFloatSemantics())); + to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); + to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); + } - auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) { + auto round_bits_to_nearest_even = [&](Val bits, Val roundoff, + bool use_implicit_bit = false) { assert(bits.value.getType() == roundoff.value.getType()); // Round to nearest even by adding a bias term. // Consider a bit pattern @@ -337,9 +370,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // - L is 1, R is 1, OR // - L is 0, R is 1, any T is one. // We do this by adding L to a bit pattern consisting of all T = 1. - Val rounded = (bits.shrui(roundoff) & 1) + - (bits.MakeConstant(1).shl(roundoff - 1) - 1); - Val bias{b.create(roundoff == 0, roundoff, rounded), &b}; + Val bias = !use_implicit_bit + ? (bits.shrui(roundoff) & 1) + + (bits.MakeConstant(1).shl(roundoff - 1) - 1) + : bits.MakeConstant(1).shl(roundoff - 1); return bits + bias; }; @@ -349,9 +383,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // Round the mantissa if it is shrinking. Val rounded_from_bits = convert_int(wide_int_ty, from_bits); if (digit_shift < 0) { - rounded_from_bits = round_bits_to_nearest_even( - from_bits, from_bits.MakeConstant(-digit_shift)) & - ~((1ll << (-digit_shift)) - 1); + rounded_from_bits = + round_bits_to_nearest_even( + rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0) & + ~((1ll << (-digit_shift)) - 1); } // Re-bias the exponent. @@ -394,10 +430,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Val bits = convert_int(wide_int_ty, from_bits); // Determine exponent in target type. - Value normalization_factor = - convert_int(i32_ty, - b.create(from_bits)) - - (from_int_ty.getWidth() - from_mantissa - 1); + Value clz = convert_int( + i32_ty, b.create(from_bits)); + Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz; + Value normalization_factor = cst(i32_ty, from_mantissa) - msb; Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor; // If the result is subnormal, adjust the subnormal bits to account for @@ -418,10 +454,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0); bits.value = b.create(biased_exp_sle_zero, subnormal_bits, normal_bits); - if (digit_shift > 0) { + if (digit_shift >= 0) { bits = bits.shl(digit_shift); } else { - bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift)); + bits = round_bits_to_nearest_even( + bits, bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0); bits = bits.shrui(-digit_shift); } bits = convert_int(to_int_ty, bits); @@ -430,11 +468,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, } else if (to_min_exp > from_min_exp) { // `To` supports fewer exponents near zero which means that some values in // `From` may become subnormal. - Val unbiased_exp = biased_from_exp - from_bias; - Val biased_to_exp = unbiased_exp + to_bias; + Val biased_to_exp = biased_from_exp + (to_bias - from_bias); // Subnormals and zero. // Round and shift mantissa down. - Val from_has_leading_one = biased_from_exp != 0; + Val from_has_leading_one = + !from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1); Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one); from_has_leading_one = convert_int(from_int_ty, from_has_leading_one); Val exponent_shift_i32 = @@ -469,31 +507,35 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, result); } - // Handle types with no unsigned zero. - auto is_nuz = [](mlir::FloatType ty) { - return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E5M2FNUZ(); - }; + Value result_is_inf = IsInf(value, b); + Value input_is_nan = IsNaN(value, b); - if (is_nuz(to_ty)) { + if (to_ty.isFloat8E8M0FNU()) { + // Converting a negative number to E8M0 results in NaN. + input_is_nan = from_sign_bit | input_is_nan; + } else if (IsFNUZ(to_ty)) { // Clear the sign bit if the result is zero (the output has no negative - // zero). - Val result_is_non_zero = Val{result, &b} != 0; + // zero). Handle the edge case when the input is zero and the result is not. + Val result_is_non_zero = + (digit_shift > 0 ? from_bits : Val{result, &b}) != 0; from_sign_bit = from_sign_bit & result_is_non_zero; - } else if (is_nuz(from_ty)) { + } else if (IsFNUZ(from_ty)) { // Clear the sign bit if the input is NaN (it's positive but encoded as // negative 0). from_sign_bit = from_sign_bit ^ input_is_nan; } + if (!from_ty.isFloat8E8M0FNU()) { + result = b.create(from_bits == 0, to_zero, result); + } result = b.create(result_is_inf, to_inf, result); - result = b.create(from_bits == 0, to_zero, result); result = b.create(input_is_nan, to_nan, result); - Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); - // Insert sign bit. - result = b.create(from_sign_bit, neg_result, result); + if (!from_ty.isFloat8E8M0FNU()) { + Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); + result = b.create(from_sign_bit, neg_result, result); + } result = b.create(to_ty, result); return result; } @@ -506,8 +548,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (dst_ty.getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit truncf"); + if (dst_ty.getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -524,8 +566,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (src.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit extf"); + if (src.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -544,8 +586,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { auto lhs = mlir::cast(op.getLhs()); auto rhs = mlir::cast(op.getRhs()); - if (lhs.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf"); + if (lhs.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -553,16 +595,16 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics()); if (op.getPredicate() == ma::CmpFPredicate::UNE && mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) { - Val int_value{b.create(rewriter.getI8Type(), lhs), &b}; + mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth()); + Val int_value{b.create(int_ty, lhs), &b}; int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. - if (rhs_cst.isZero() && - (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || - lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { - int_value = int_value & 0x7f; - constant &= 0x7f; + if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) { + int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1; + int_value = int_value & mask; + constant &= mask; } - auto cst = b.create(constant, rewriter.getI8Type()); + auto cst = b.create(constant, int_ty); rewriter.replaceOpWithNewOp(op, ma::CmpIPredicate::ne, int_value, cst); return mlir::success(); @@ -586,18 +628,23 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern { auto src = mlir::cast(op.getOperand()); // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16. // Once that's removed, remove the code for BF16 here. - if (src.getType().getWidth() != 8 && !src.getType().isBF16()) { - return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf"); + if (src.getType().getWidth() > 8 && !src.getType().isBF16()) { + return rewriter.notifyMatchFailure(op, + "not an f8 (or less) or bf16 absf"); } + + // If type is unsigned (E8M0), the operation is no-op. + if (!llvm::APFloat::semanticsHasSignedRepr( + src.getType().getFloatSemantics())) { + rewriter.replaceAllOpUsesWith(op, op.getOperand()); + return mlir::success(); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); Val value{b.create(i_ty, src), &b}; - if (src.getType().getWidth() == 8) { - value = value & 0x7f; - } else { - CHECK(src.getType().isBF16()); - value = value & 0x7fff; - } + int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1; + value = value & mask; rewriter.replaceOpWithNewOp(op, src.getType(), value); return mlir::success(); } @@ -609,8 +656,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 itofp"); + if (op.getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp"); } Value to_float = rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); @@ -625,8 +672,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getIn().getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 fptoi"); + if (op.getIn().getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi"); } Value to_f32 = rewriter.create( op.getLoc(), rewriter.getF32Type(), op.getIn()); diff --git a/xla/backends/gpu/codegen/transforms/lower_tensors.cc b/xla/backends/gpu/codegen/transforms/lower_tensors.cc index 0fff3bc811bbca..c7c7ae83d21bb4 100644 --- a/xla/backends/gpu/codegen/transforms/lower_tensors.cc +++ b/xla/backends/gpu/codegen/transforms/lower_tensors.cc @@ -297,7 +297,8 @@ std::tuple GetI4IndexAndNibble(Value linear_index, mlir::LLVM::GEPOp CreateGep(TypedValue tensor, Value linear_index, mlir::ImplicitLocOpBuilder& b) { Type element_type = tensor.getType().getElementType(); - if (element_type == b.getI4Type()) { + if (element_type.isIntOrFloat() && + element_type.getIntOrFloatBitWidth() == 4) { element_type = b.getI8Type(); } auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); @@ -326,7 +327,8 @@ struct RewriteTensorExtract : OpRewritePattern { auto linear_index = GetLinearIndex(op.getIndices(), b); Type element_type = op.getTensor().getType().getElementType(); Value is_low_nibble = nullptr; - if (element_type == rewriter.getI4Type()) { + if (element_type.isIntOrFloat() && + element_type.getIntOrFloatBitWidth() == 4) { std::tie(linear_index, is_low_nibble) = GetI4IndexAndNibble(linear_index, b); } @@ -341,7 +343,7 @@ struct RewriteTensorExtract : OpRewritePattern { auto high_value = b.create( load, b.create(4, load.getType())); load = b.create( - op.getType(), + rewriter.getI4Type(), b.create(is_low_nibble, load, high_value)); } @@ -377,6 +379,7 @@ struct RewriteTransferRead : OpRewritePattern { auto source = mlir::dyn_cast>( op.getSource()); + mlir::Type source_element_type = source.getType().getElementType(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto linear_index = GetLinearIndex(op.getIndices(), b); @@ -385,7 +388,9 @@ struct RewriteTransferRead : OpRewritePattern { if (vector_type.getElementType().isInteger(1)) { vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type()); } - if (op.getVectorType().getElementType().isInteger(4)) { + mlir::Type gep_element_type = vector_type.getElementType(); + if (gep_element_type.isIntOrFloat() && + gep_element_type.getIntOrFloatBitWidth() == 4) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern { auto loaded = b.create(llvm_vector_type, gep).getResult(); - if (source.getType().getElementType().isInteger(1)) { + if (source_element_type.isInteger(1)) { Value zero = b.create( mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0))); loaded = b.create(arith::CmpIPredicate::ne, loaded, zero); - } else if (source.getType().getElementType().isInteger(4)) { + } else if (source_element_type.isIntOrFloat() && + source_element_type.getIntOrFloatBitWidth() == 4) { // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the // elements. loaded = PermutePairsInVector(loaded, b); @@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern { auto scalar_value = op.getScalar(); // For i4 we store 2 values into one byte. This needs special handling here. - if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) { + if (tensor_dest.getType().getElementType().isIntOrFloat() && + tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) { // We need to use directly op.getDest() as input, otherwise the following // rewrite might remove the only user of it. tensor_dest = op.getDest(); @@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern { auto tensor_dest_i8 = b.create(tensor_ty, tensor_dest) .getResult(0); + if (scalar_value.getType() != rewriter.getI4Type()) { + scalar_value = + b.create(rewriter.getI4Type(), scalar_value); + } scalar_value = b.create(ty, scalar_value); // We need AtomicRMWOp because it can happen that different threads try to @@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern { auto linear_index = GetLinearIndex(op.getIndices(), b); mlir::Value vector_value = op.getVector(); - if (op.getVectorType().getElementType().isInteger(1)) { + mlir::Type vector_element_type = op.getVectorType().getElementType(); + if (vector_element_type.isInteger(1)) { vector_value = b.create( op.getVectorType().cloneWith(std::nullopt, b.getI8Type()), vector_value); } - if (op.getVectorType().getElementType().isInteger(4)) { + if (vector_element_type.isIntOrFloat() && + vector_element_type.getIntOrFloatBitWidth() == 4) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -577,21 +590,19 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, // Needed to support complex element type. mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - if (mlir::isa(element_type)) { - int bit_width = mlir::cast(element_type).getWidth(); - if (bit_width == 4) { - num_elements = CeilOfRatio(num_elements, 2); - llvm_element_type = b.getI8Type(); - auto unpacked_data = - mlir::cast(value).getRawData(); - std::vector packed_data(num_elements); - absl::Span packed_data_span = - absl::MakeSpan(packed_data.data(), packed_data.size()); - PackIntN(4, unpacked_data, packed_data_span); - value = mlir::DenseElementsAttr::getFromRawBuffer( - mlir::RankedTensorType::get({num_elements}, llvm_element_type), - packed_data); - } + if (element_type.isIntOrFloat() && + element_type.getIntOrFloatBitWidth() == 4) { + num_elements = CeilOfRatio(num_elements, 2); + llvm_element_type = b.getI8Type(); + auto unpacked_data = + mlir::cast(value).getRawData(); + std::vector packed_data(num_elements); + absl::Span packed_data_span = + absl::MakeSpan(packed_data.data(), packed_data.size()); + PackIntN(4, unpacked_data, packed_data_span); + value = mlir::DenseElementsAttr::getFromRawBuffer( + mlir::RankedTensorType::get({num_elements}, llvm_element_type), + packed_data); } auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements); diff --git a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir index 442fe5e9291572..dea8988d474b05 100644 --- a/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir @@ -115,3 +115,53 @@ module { // CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32 // CHECK: arith.truncf %[[EXT]] : f32 to f16 // CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 { + %ret = arith.extf %arg0 : f4E2M1FN to f16 + return %ret : f16 + } +} + +// CHECK-LABEL: @f4_to_f16 +// CHECK-NOT: arith.extf + +// ----- + +module { + func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN { + %ret = arith.truncf %arg0 : f16 to f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f16_to_f4 +// CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN { + %ret = math.absf %arg0 : f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f4_abs +// CHECK-NOT: math.absf +// CHECK: arith.constant 7 : i4 + +// ----- + +module { + func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU { + %ret = math.absf %arg0 : f8E8M0FNU + return %ret : f8E8M0FNU + } +} + +// CHECK-LABEL: @e8m0_abs +// CHECK-NOT: math.absf +// CHECK: return %arg0 diff --git a/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir b/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir index 646c7a00ff756f..864f68d1da6f49 100644 --- a/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir @@ -763,4 +763,44 @@ func.func @for_op(%arg0: tensor<500xf32>) -> f32 { // CHECK-LABEL: @for_op // CHECK: scf.for {{.*}} -> (vector<4xf32>) { -// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { \ No newline at end of file +// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { + +// ----- + +func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN { + %cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN> + %extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN> + %0 = arith.addf %extracted, %extracted_0 : f4E2M1FN + return %0 : f4E2M1FN +} +// CHECK: llvm.mlir.global private constant +// CHECK-SAME: dense<[25, 64]> +// CHECK-LABEL: @f4_constant + +// ----- + +func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> { + %c16 = arith.constant 16 : index + %c0 = arith.constant 0.0 : f4E2M1FN + %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN> + func.return %out : vector<2xf4E2M1FN> +} +// CHECK-LABEL: @transfer_read_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8] +// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4> +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN> +// CHECK: return %[[OUT]] : vector<2xf4E2M1FN> + +// ----- + +func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}, + %arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> { + %c10 = arith.constant 10 : index + %out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN> + func.return %out : tensor<43xf4E2M1FN> +} +// CHECK-LABEL: @transfer_write_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8 +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4> +// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr diff --git a/xla/comparison_util.h b/xla/comparison_util.h index 5a21595da4d741..44f0dd48640bb1 100644 --- a/xla/comparison_util.h +++ b/xla/comparison_util.h @@ -193,8 +193,13 @@ class Comparison { // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN // Reference: // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations - using R = SignedIntegerTypeForSizeType; - return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + if constexpr (std::numeric_limits::is_signed) { + using R = SignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } else { + using R = UnsignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } } } // Applies the comparison from this Comparison's direction and ordering. diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index 389d2d2a9a7aec..9787476f8f7eac 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "C128"; case XLA_FFI_DataType_TOKEN: return os << "TOKEN"; + case XLA_FFI_DataType_F4E2M1FN: + return os << "F4E2M1FN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; case XLA_FFI_DataType_F8E3M4: @@ -145,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "F8E5M2FNUZ"; case XLA_FFI_DataType_F8E4M3FNUZ: return os << "F8E4M3FNUZ"; + case XLA_FFI_DataType_F8E8M0FNU: + return os << "F8E8M0FNU"; } } diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index 8d6f1095fad24a..bf8cb7d1a8ad19 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -201,6 +201,8 @@ typedef enum { XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, XLA_FFI_DataType_F8E4M3FNUZ = 25, + XLA_FFI_DataType_F4E2M1FN = 32, + XLA_FFI_DataType_F8E8M0FNU = 33, } XLA_FFI_DataType; // LINT.ThenChange(ffi_test.cc) diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index f264451da34735..aeeab1d505ab66 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -79,6 +79,8 @@ enum class DataType : uint8_t { F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, F8E3M4 = XLA_FFI_DataType_F8E3M4, + F4E2M1FN = XLA_FFI_DataType_F4E2M1FN, + F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; inline constexpr DataType F8E3M4 = DataType::F8E3M4; +inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN; +inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: case DataType::F8E3M4: + case DataType::F4E2M1FN: + case DataType::F8E8M0FNU: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index f09588b9e986a2..3c51a0966ae02e 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); + EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); @@ -137,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); + EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU)); } TEST(FfiTest, DataTypeByteWidth) { @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128), ByteWidth(DataType::C128)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN), + ByteWidth(DataType::F4E2M1FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), @@ -193,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E4M3FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), ByteWidth(DataType::F8E3M4)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU), + ByteWidth(DataType::F8E8M0FNU)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 3fb2ac3c7786fa..7bcb14da445e8c 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -264,6 +264,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C64: case PrimitiveType::C128: case PrimitiveType::TOKEN: + case PrimitiveType::F4E2M1FN: case PrimitiveType::F8E5M2: case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: @@ -271,6 +272,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: case PrimitiveType::F8E3M4: + case PrimitiveType::F8E8M0FNU: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 3eb3561a264d40..5de6f4e33a2018 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -119,6 +119,76 @@ class FP8E4M3DistanceTest : public ::testing::Test {}; using F8E4M3Types = ::testing::Types; TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); +TEST(FPDistanceTest, F4E2M1FNDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)), + 1); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)), + 2); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float4_e2m1fn(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float4_e2m1fn(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::min(), + tsl::float4_e2m1fn(0)), + 2); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + tsl::float4_e2m1fn(0)), + 2); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 4); +} + +TEST(FPDistanceTest, F8E8M0FNUDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)), + 0); + + // one step apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)), + 1); + + // two steps apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)), + 2); +} + TEST(FPDistanceTest, F8E3M4Distance) { // a & b are equal EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc index 3a72875d2733de..b39619de18cc72 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -185,6 +185,7 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F4E2M1FN: case F8E3M4: case F8E4M3: case F8E5M2: @@ -972,8 +973,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, + F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1025,8 +1027,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, + F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc index cf56e0e39cf2b0..e468282f7c3e0c 100644 --- a/xla/hlo/builder/lib/math_test.cc +++ b/xla/hlo/builder/lib/math_test.cc @@ -96,9 +96,13 @@ class MathTypedTest : public MathTest { Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); bool has_inf = std::numeric_limits::has_infinity; + bool has_nan = std::numeric_limits::has_quiet_NaN; + bool has_finite = !has_inf && !has_nan; + bool has_nan_only = !has_inf && has_nan; + auto expected = LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR1( - {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1({true, true, true, true, true, has_finite, + has_finite, has_finite, has_finite}), LiteralUtil::CreateR1({false, false, false, false, false, has_inf, has_inf, false, false}), LiteralUtil::CreateR1( @@ -106,7 +110,8 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1( {false, false, false, false, false, false, has_inf, false, false}), LiteralUtil::CreateR1({false, false, false, false, false, - !has_inf, !has_inf, true, true})); + has_nan_only, has_nan_only, has_nan, + has_nan})); ComputeAndCompareLiteral(&b, expected, {}); } @@ -119,10 +124,11 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), &b)); + bool is_mx = std::is_same_v; ComputeAndCompareLiteral( &b, LiteralUtil::CreateR1( - {has_negative_zero_v, false, false, false, false, false, false}), + {has_negative_zero_v, false, false, false, false, false, is_mx}), {}, error_spec_); } @@ -137,6 +143,9 @@ class MathTypedTest : public MathTest { // For good measure, we also check pow with an exponent other than 0.5. void TestSqrtPowInequivalence() { SetFastMathDisabled(true); + if (std::is_same_v) { + GTEST_SKIP() << "Skipping due to low precision"; + } // Tests disable constant folding by default, but this test needs it // enabled, otherwise we don't tickle the bug we're trying to catch. @@ -182,9 +191,14 @@ class MathTypedTest : public MathTest { &b); Erf(x); - bool has_inf = std::numeric_limits::has_infinity; - std::vector expected = { - has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)}; + bool inf_as_nan = !std::numeric_limits::has_infinity && + std::numeric_limits::has_quiet_NaN; + std::vector expected = {inf_as_nan ? nan : T(-1), + inf_as_nan ? nan : T(1), + T(-0), + T(0), + T(-1), + T(1)}; ComputeAndCompareR1(&b, expected, {}, error_spec_); } @@ -202,6 +216,10 @@ using TestTypes = #endif #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 double, +#endif +#ifndef XLA_TEST_BACKEND_TPU + // TODO(b/385004399): Run tests on these types on TPU. + tsl::float4_e2m1fn, #endif float>; diff --git a/xla/hlo/evaluator/BUILD b/xla/hlo/evaluator/BUILD index b9a3c348ef30be..a6d20a8eeafd13 100644 --- a/xla/hlo/evaluator/BUILD +++ b/xla/hlo/evaluator/BUILD @@ -37,6 +37,7 @@ cc_library( "hlo_evaluator_typed_visitor_int4.cc", "hlo_evaluator_typed_visitor_int64.cc", "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_mxfloat.cc", "hlo_evaluator_typed_visitor_uint16.cc", "hlo_evaluator_typed_visitor_uint32.cc", "hlo_evaluator_typed_visitor_uint64.cc", diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 35fac878f104da..8e44243823c097 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -3722,7 +3722,7 @@ absl::StatusOr StochasticConvertOp(const Literal& operand_literal, const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { - bool is_negative = static_cast(Eigen::numext::signbit(operand)); + bool is_negative = static_cast(SignAndMagnitude(operand).first); if (Eigen::numext::isinf(operand)) { return is_negative ? std::numeric_limits::min() : std::numeric_limits::max(); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 74feab55e5e9c8..8499b0ab7107dc 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1736,6 +1736,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; @@ -1743,6 +1744,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc new file mode 100644 index 00000000000000..6bc96c1a1f7cda --- /dev/null +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc @@ -0,0 +1,23 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "tsl/platform/ml_dtypes.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/xla/hlo/transforms/expanders/comparison_expander.cc b/xla/hlo/transforms/expanders/comparison_expander.cc index 61a4305b09d5b9..7218d076aab0fe 100644 --- a/xla/hlo/transforms/expanders/comparison_expander.cc +++ b/xla/hlo/transforms/expanders/comparison_expander.cc @@ -118,34 +118,41 @@ absl::StatusOr ComparisonExpander::ExpandInstruction( ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); - PrimitiveType signed_type = - primitive_util::SignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); - - auto zero_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); - - auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MinValue(signed_shape.element_type()))); - min_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, min_value, {})); - - auto max_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - max_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, max_value, {})); - - lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, - min_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, - min_value, max_value); + if (compare_type != F8E8M0FNU) { + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(signed_type))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); + } else { + auto int8_shape = ShapeUtil::ChangeElementType(lhs->shape(), U8); + lhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, lhs)); + rhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, rhs)); + } auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( - instruction->shape(), lhs, rhs, compare->direction(), - Comparison::Type::kSigned)); + instruction->shape(), lhs, rhs, compare->direction())); VLOG(2) << "New comparison instruction for total order:" << new_compare->ToString(); diff --git a/xla/hlo/transforms/simplifiers/float_normalization.cc b/xla/hlo/transforms/simplifiers/float_normalization.cc index 88dbd2781ca60f..cf978bf581fcde 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -217,6 +217,9 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) { if (subshape->element_type() == from) { subshape->set_element_type(to); + if (subshape->has_layout() && from == F4E2M1FN) { + subshape->mutable_layout()->set_element_size_in_bits(0); + } } }); float_normalization_->UpdateLayout(hlo->mutable_shape()); diff --git a/xla/hlo/transforms/simplifiers/float_normalization_test.cc b/xla/hlo/transforms/simplifiers/float_normalization_test.cc index 86ec889abc6527..b614f74229c0e5 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization_test.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -150,7 +150,9 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); + ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3, + F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, + F8E5M2, F8E5M2FNUZ, F8E8M0FNU)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index f70769ea91abec..cea1bc583ea56e 100644 --- a/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -69,6 +70,25 @@ ::mlir::DenseElementsAttr CreateDenseAttrFromLiteral( } return ::mlir::DenseElementsAttr::getFromRawBuffer(type, packed_padded_data); + } else if constexpr (std::is_same_v) { + // DenseElementsAttr::get() does not support being passed an array of + // tsl::float4_e2m1fn. So convert each element to APFloat first. + auto data_span = literal.data(); + std::vector apfloats; + apfloats.reserve(literal.element_count()); + for (size_t i = 0; i < literal.element_count(); i++) { + llvm::APFloat apfloat{static_cast(data_span[i])}; + bool losesInfo; + llvm::APFloat::opStatus status = + apfloat.convert(llvm::APFloat::Float4E2M1FN(), + llvm::APFloat::rmNearestTiesToEven, &losesInfo); + CHECK_EQ(status, llvm::APFloat::opOK) + << "Failed to convert " << data_span[i] << " to Float4E2M1FN APFloat"; + CHECK(!losesInfo) << "Lost info when converting " << data_span[i] + << " to Float4E2M1FN APFloat"; + apfloats.push_back(apfloat); + } + return ::mlir::DenseElementsAttr::get(type, apfloats); } else { auto data_span = literal.data(); return ::mlir::DenseElementsAttr::get( diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 3a1e7ceabb160f..577e4ad61f89e2 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -421,6 +421,12 @@ add { // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + %constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_15:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + %constant.15 = f8e8m0fnu[4] constant({1, 2, 4, 8}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -542,7 +548,19 @@ add { %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> - ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) + %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) + + // CHECK-NEXT: %14 = mhlo.convert %13 : (tensor<4xf32>) -> tensor<4xf4E2M1FN> + %convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16) + + // CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32> + %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17) + + // CHECK-NEXT: %16 = mhlo.convert %15 : (tensor<4xf32>) -> tensor<4xf8E8M0FNU> + %convert.19 = f8e8m0fnu[4] convert(f32[4] %convert.18) + + // CHECK-NEXT: %17 = mhlo.convert %16 : (tensor<4xf8E8M0FNU>) -> tensor<4xf32> + ROOT %convert.20 = f32[4] convert(f8e8m0fnu[4] %convert.19) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc index 821f1487cf88c1..f50e2a097a3277 100644 --- a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc +++ b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc @@ -41,6 +41,12 @@ xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { xla::Array array(shape.dimensions()); if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { array.SetValues(dense_attr.getValues()); + } else if constexpr (xla::primitive_util::IsMXType(type)) { + // Bitcast MX floating point types from APFloat. + auto values = dense_attr.getValues(); + for (int i = 0; i < values.size(); i++) { + array.data()[i] = T::FromRep(values[i].bitcastToAPInt().getZExtValue()); + } } else { // The only way to get subbyte integers from getValues() is to get them as // APInts. diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index a22ec331d93b20..c017751477cb51 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -606,6 +606,12 @@ func.func @main() { // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + // CHECK: f4e2m1fn[4] constant({1, 2, 3, 4}) + %cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + + // CHECK: f8e8m0fnu[4] constant({1, 2, 4, 8}) + %cst_19 = arith.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + func.return } @@ -739,7 +745,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> - func.return %11 : tensor<2xf32> + %12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN> + %13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32> + %14 = "mhlo.convert"(%13) : (tensor<2xf32>) -> tensor<2xf8E8M0FNU> + %15 = "mhlo.convert"(%14) : (tensor<2xf8E8M0FNU>) -> tensor<2xf32> + func.return %15 : tensor<2xf32> } // CHECK: ENTRY @@ -755,7 +765,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) // CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) // CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) -// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) +// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) +// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]]) +// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) +// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]]) +// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 6b5db7f893ec4c..cf2a1492b60683 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -92,6 +92,7 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { proto.s64s_size() || !proto.u1s().empty() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || + !proto.f4e2m1fns().empty() || !proto.f8e8m0fnus().empty() || !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || @@ -1875,7 +1876,6 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(size_bytes_dense(), other.size_bytes_dense()); if (primitive_util::IsSubByteNonPredType(subshape().element_type())) { - CHECK(!primitive_util::IsFloatingPointType(subshape().element_type())); auto one_array = buffer(); auto two_array = other.buffer(); const int bits_per_element = @@ -2268,6 +2268,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case F4E2M1FN: + *proto->mutable_f4e2m1fns() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E5M2: *proto->mutable_f8e5m2s() = std::string( reinterpret_cast(data().data()), @@ -2303,6 +2308,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E8M0FNU: + *proto->mutable_f8e8m0fnus() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2454,6 +2464,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case F4E2M1FN: { + const std::string& s(proto.f4e2m1fns()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float4_e2m1fn) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E5M2: { const std::string& s(proto.f8e5m2s()); TF_RET_CHECK(data().size() * sizeof(tsl::float8_e5m2) == @@ -2507,6 +2525,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E8M0FNU: { + const std::string& s(proto.f8e8m0fnus()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e8m0fnu) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal.h b/xla/literal.h index 1b76f2effe6a94..0c1985488cd2bb 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -600,18 +600,17 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { - static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte = 0; for (int b = 0; b < elements_per_byte; ++b) { - uint8_t src = - static_cast(elements[i * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = Eigen::numext::bit_cast( + elements[i * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -620,9 +619,9 @@ class LiteralBase { if (rest != 0) { uint8_t byte = 0; for (int64_t b = 0; b < rest; ++b) { - uint8_t src = - static_cast(elements[bytes * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = Eigen::numext::bit_cast( + elements[bytes * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -712,11 +711,17 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { - static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + constexpr auto cast = [](uint8_t x) -> NativeT { + if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { + return Eigen::numext::bit_cast(x); + } + return static_cast(x); + }; + + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte; @@ -725,7 +730,7 @@ class LiteralBase { } for (int b = 0; b < elements_per_byte; ++b) { elements[i * elements_per_byte + b] = - static_cast(byte & LsbMask(bits_per_element)); + cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } @@ -737,7 +742,7 @@ class LiteralBase { } for (int64_t b = 0; b < rest; ++b) { elements[bytes * elements_per_byte + b] = - static_cast(byte & LsbMask(bits_per_element)); + cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index c97629594122bb..ecea5024963934 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -206,8 +206,8 @@ template std::string FpValueToString(NativeT value) { if constexpr (is_specialized_floating_point_v) { constexpr int kPrecisionDigits = std::numeric_limits::max_digits10; - const int kExponentDigts = - std::ceil(std::log10(std::numeric_limits::max_exponent10)); + const int kExponentDigts = std::ceil( + std::log10(std::max(std::numeric_limits::max_exponent10, 1))); constexpr int kExtraChars = 4; const int kTotalChars = kPrecisionDigits * kExponentDigts + kExtraChars; return absl::StrFormat("%*.*g", kTotalChars, kPrecisionDigits, @@ -418,6 +418,9 @@ class NearComparator { } else { float_distance = CalculateFloatDistance(expected, actual); abs_error = FpAbsoluteValue(actual - expected); + if (!std::numeric_limits::is_signed && IsNaN(abs_error)) { + abs_error = FpAbsoluteValue(expected - actual); + } // Avoid division by 0 even though it's well-defined because ubsan can be // configured to treat this as a fatal error. diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 4dcdad85fd5d43..ad0d6f2b70dfac 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -30,13 +30,15 @@ template class LiteralComparisonTest : public ::testing::Test {}; using TestedTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + auto expected = LiteralUtil::CreateR0(TypeParam(1.0)); TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); @@ -44,12 +46,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = 9.0; // F8E4M3* - if (type == F8E5M2) - expV = 10.0; + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + float expV = 1.125; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 1.25; else if (type == F8E3M4) - expV = 8.5; + expV = 1.0625; + else if (type == F4E2M1FN) + expV = 1.5; + else if (type == F8E8M0FNU) + expV = 2.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, @@ -64,12 +70,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = 12.0; // F8E4M3* - if (type == F8E5M2) - expV = 14.0; + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + float expV = 1.5; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 2.0; else if (type == F8E3M4) - expV = 10.0; + expV = 1.25; + else if (type == F4E2M1FN) + expV = 4.0; + else if (type == F8E8M0FNU) + expV = 16.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; @@ -86,12 +96,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(8.0); - float expV = 12.1; // F8E4M3* - if (type == F8E5M2) - expV = 13.0; + auto actual = LiteralUtil::CreateR0(1.0); + float expV = 1.51; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 2.01; else if (type == F8E3M4) - expV = 10.125; + expV = 1.26; + else if (type == F4E2M1FN) + expV = 4.1; + else if (type == F8E8M0FNU) + expV = 16.5; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 5bbddd572c8a64..60ba2985307816 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -124,11 +124,11 @@ class LiteralUtilTest : public ::testing::Test { template class LiteralUtilFloatTest : public LiteralUtilTest {}; -using FloatTypes = - ::testing::Types; +using FloatTypes = ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -187,6 +187,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { LiteralUtil::CreateR0(static_cast(9.001f)); EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); + auto f4e2m1fn_lit = + LiteralUtil::CreateR0(tsl::float4_e2m1fn(0.5)); + EXPECT_EQ("f4e2m1fn[] 0.5", f4e2m1fn_lit.ToString()); + auto f8e5m2_lit = LiteralUtil::CreateR0(tsl::float8_e5m2(0.5)); EXPECT_EQ("f8e5m2[] 0.5", f8e5m2_lit.ToString()); @@ -219,6 +223,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e3m4_lit = LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); + + auto f8e8m0fnu_lit = + LiteralUtil::CreateR0(tsl::float8_e8m0fnu(0.5)); + EXPECT_EQ("f8e8m0fnu[] 0.5", f8e8m0fnu_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -671,6 +679,11 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); + tsl::float4_e2m1fn m16(4); + EXPECT_TRUE(LiteralUtil::CreateR1({m16}).IsAll(4)); + // 5 rounds to 4 in E2M1FN but is not equal to 4, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({m16}).IsAll(5)); + tsl::float8_e5m2 p16(8); EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false @@ -701,6 +714,11 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + tsl::float8_e8m0fnu w16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({w16}).IsAll(8)); + // 9 rounds to 8 in E8M0FNU but is not equal to 8, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({w16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2226,6 +2244,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + using e2m1 = tsl::float4_e2m1fn; + auto vector_f4e2m1fn = + LiteralUtil::CreateR1({e2m1{1.0}, e2m1{2.0}, e2m1{-3.0}}); using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); @@ -2246,6 +2267,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); using e3 = tsl::float8_e3m4; auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); + using e8m0 = tsl::float8_e8m0fnu; + auto vector_f8e8m0fnu = + LiteralUtil::CreateR1({e8m0{1.0}, e8m0{2.0}, e8m0{4.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2266,13 +2290,15 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); - EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f4e2m1fn, to_from_proto(vector_f4e2m1fn)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); - EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); - EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); - EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); + EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e8m0fnu, to_from_proto(vector_f8e8m0fnu)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2523,19 +2549,19 @@ TEST_F(LiteralUtilTest, SliceOnBool) { } TEST_F(LiteralUtilTest, IsEqualAt) { - double val_double = 10.0; - int val_integral = 10; - Literal c1 = LiteralUtil::CreateR0(10); + double val_double = 4.0; + int val_integral = 4; + Literal c1 = LiteralUtil::CreateR0(val_integral); EXPECT_TRUE(c1.IsEqualAt({}, val_double)); EXPECT_TRUE(c1.IsEqualAt({}, val_integral)); - Literal c2 = LiteralUtil::CreateR0(10); + Literal c2 = LiteralUtil::CreateR0(val_double); EXPECT_TRUE(c2.IsEqualAt({}, val_double)); EXPECT_TRUE(c2.IsEqualAt({}, val_integral)); Literal c3 = LiteralUtil::CreateR0(tsl::float8_e5m2{val_double}); EXPECT_TRUE(c3.IsEqualAt({}, val_double)); EXPECT_TRUE(c3.IsEqualAt({}, val_integral)); - complex128 val_complex = {10, 0}; + complex128 val_complex = {val_double, 0}; EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); EXPECT_TRUE(c3.IsEqualAt({}, val_complex)); @@ -2544,8 +2570,8 @@ TEST_F(LiteralUtilTest, IsEqualAt) { EXPECT_TRUE(c4.IsEqualAt({}, val_integral)); EXPECT_TRUE(c4.IsEqualAt({}, val_complex)); EXPECT_FALSE(c4.IsEqualAt({}, std::numeric_limits::infinity())); - complex128 val_true_complex = {10, 3}; - complex64 val_smaller_complex = {10, 3}; + complex128 val_true_complex = {val_double, 3}; + complex64 val_smaller_complex = {static_cast(val_double), 3}; Literal c5 = LiteralUtil::CreateR0(val_true_complex); EXPECT_TRUE(c5.IsEqualAt({}, val_true_complex)); EXPECT_TRUE(c5.IsEqualAt({}, val_smaller_complex)); @@ -2569,6 +2595,14 @@ TEST_F(LiteralUtilTest, IsEqualAt) { LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); EXPECT_TRUE(c10.IsEqualAt({}, val_double)); EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); + Literal c11 = + LiteralUtil::CreateR0(tsl::float4_e2m1fn{val_double}); + EXPECT_TRUE(c11.IsEqualAt({}, val_double)); + EXPECT_TRUE(c11.IsEqualAt({}, val_integral)); + Literal c12 = LiteralUtil::CreateR0( + tsl::float8_e8m0fnu{val_double}); + EXPECT_TRUE(c12.IsEqualAt({}, val_double)); + EXPECT_TRUE(c12.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2894,10 +2928,11 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F4E2M1FN, F8E3M4, F8E4M3, + F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F8E8M0FNU, + C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 2581390a1e13d7..ea8da4d4990d9d 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -32,6 +32,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( switch (type) { case xla::PrimitiveType::PRED: return b.getI1Type(); + case xla::PrimitiveType::F4E2M1FN: + return b.getFloat4E2M1FNType(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); case xla::PrimitiveType::F8E4M3: @@ -46,6 +48,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E4M3FNUZType(); case xla::PrimitiveType::F8E3M4: return b.getFloat8E3M4Type(); + case xla::PrimitiveType::F8E8M0FNU: + return b.getFloat8E8M0FNUType(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -78,7 +82,9 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( } xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { - if (type.isFloat8E5M2()) { + if (type.isFloat4E2M1FN()) { + return xla::PrimitiveType::F4E2M1FN; + } else if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; } else if (type.isFloat8E4M3()) { return xla::PrimitiveType::F8E4M3; @@ -92,6 +98,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E5M2FNUZ; } else if (type.isFloat8E3M4()) { return xla::PrimitiveType::F8E3M4; + } else if (type.isFloat8E8M0FNU()) { + return xla::PrimitiveType::F8E8M0FNU; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index a8043ab0b5f140..2239943d906b7b 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P( Execute, TypeUtilTest, ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, + {F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, @@ -111,6 +112,7 @@ INSTANTIATE_TEST_SUITE_P( {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, + {F8E8M0FNU, [](mlir::Builder b) { return b.getFloat8E8M0FNUType(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index d07a178c6c4e7f..16a64cdc22b768 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e3m4(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor @@ -6872,6 +6879,13 @@ func.func @f8e5m2(%arg0: tensor) -> tensor { // ----- +func.func @f8e8m0fnu(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @top_k_1d(%arg0 : tensor<16xf32>) { %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index d56741eb3500b0..6b84e2766bc44e 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,4 +1,6 @@ # PJRT C API changelog +## 0.62 +* Added types F4E2M1FN and F8E8M0FNU. ## 0.61 * Added ``PJRT_KeyValueTryGet`` to the KV store interface, diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index f2fc3b1c507a3c..c27e8d7dd97a2e 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 61 +#define PJRT_API_MINOR 62 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -717,6 +717,10 @@ typedef enum { // More truncated 8 bit floating-point formats. PJRT_Buffer_Type_F8E4M3, PJRT_Buffer_Type_F8E3M4, + PJRT_Buffer_Type_F8E8M0FNU, + + // 4-bit MX floating-point format. + PJRT_Buffer_Type_F4E2M1FN, } PJRT_Buffer_Type; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index c5113d1766ef66..1a697b0a4cae9f 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -310,6 +310,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_BF16; case xla::PrimitiveType::F64: return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; + case xla::PrimitiveType::F4E2M1FN: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; case xla::PrimitiveType::F8E4M3: @@ -324,6 +326,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; case xla::PrimitiveType::F8E3M4: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; + case xla::PrimitiveType::F8E8M0FNU: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -377,6 +381,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C64; case PJRT_Buffer_Type::PJRT_Buffer_Type_C128: return xla::PrimitiveType::C128; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: @@ -391,6 +397,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E4M3FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: return xla::PrimitiveType::F8E3M4; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.cc b/xla/primitive_util.cc index f09b9b7a1edb50..8dcc2376ef2e77 100644 --- a/xla/primitive_util.cc +++ b/xla/primitive_util.cc @@ -94,6 +94,18 @@ bool HasInfinity(PrimitiveType type) { return false; } +bool HasNaN(PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + return FloatingPointTypeSwitch( + [&](auto constant_type) -> bool { + return std::numeric_limits< + NativeTypeOf>::has_quiet_NaN; + }, + type); + } + return false; +} + bool HasNegativeZero(PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { return FloatingPointTypeSwitch( diff --git a/xla/primitive_util.h b/xla/primitive_util.h index b9c1c978bc620e..70a8335c8bc518 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -69,6 +69,9 @@ int ExponentBias(PrimitiveType type); // Returns whether the type has a value for infinity. bool HasInfinity(PrimitiveType type); +// Returns whether the type has a value for NaN. +bool HasNaN(PrimitiveType type); + // Returns whether the type has a value for negative zero. bool HasNegativeZero(PrimitiveType type); @@ -185,6 +188,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return BF16; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F4E2M1FN; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; @@ -220,6 +228,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E3M4; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E8M0FNU; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -334,6 +347,11 @@ struct PrimitiveTypeToNative { using type = bfloat16; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float4_e2m1fn; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; @@ -369,6 +387,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e3m4; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e8m0fnu; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -401,6 +424,10 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { primitive_type < PrimitiveType_ARRAYSIZE; } +constexpr bool IsMXType(PrimitiveType type) { + return type == F4E2M1FN || type == F8E8M0FNU; +} + constexpr bool IsF8Type(PrimitiveType type) { return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || @@ -409,7 +436,7 @@ constexpr bool IsF8Type(PrimitiveType type) { constexpr bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16 || - IsF8Type(type); + IsF8Type(type) || IsMXType(type); } constexpr bool IsComplexType(PrimitiveType type) { @@ -473,6 +500,9 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F4E2M1FN: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E3M4: return std::forward(f)( PrimitiveTypeConstant()); @@ -494,6 +524,9 @@ constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { case F8E5M2FNUZ: return std::forward(f)( PrimitiveTypeConstant()); + case F8E8M0FNU: + return std::forward(f)( + PrimitiveTypeConstant()); case F16: return std::forward(f)(PrimitiveTypeConstant()); case BF16: @@ -577,6 +610,9 @@ inline constexpr int PrimitiveTypeBitWidth() { if constexpr (primitive_type == PRED) { return std::numeric_limits::digits; } + if constexpr (IsMXType(primitive_type)) { + return NativeT::kBits; + } if constexpr (IsFloatingPointType(primitive_type)) { return sizeof(NativeT) * std::numeric_limits::digits; } @@ -715,6 +751,10 @@ inline bool CastPreservesValues(PrimitiveType from_type, if (from_type == to_type) { return true; } + // * -> F8E8M0FNU is not possible because zero cannot be represented. + if (to_type == F8E8M0FNU) { + return false; + } // PRED -> * if (from_type == PRED) { return true; @@ -737,21 +777,33 @@ inline bool CastPreservesValues(PrimitiveType from_type, return false; } // F -> F is safe if the exponent/significand are preserved and `to_type` - // preserves infinities in `from_type. + // preserves infinities/nans/unsigned zero in `from_type`. if (primitive_util::IsFloatingPointType(from_type) && primitive_util::IsFloatingPointType(to_type)) { - return (!primitive_util::HasInfinity(from_type) || - primitive_util::HasInfinity(to_type)) && - primitive_util::SignificandWidth(from_type) <= - primitive_util::SignificandWidth(to_type) && - primitive_util::ExponentWidth(from_type) <= - primitive_util::ExponentWidth(to_type) && - (primitive_util::UnderflowExponent(from_type) - - primitive_util::SignificandWidth(from_type)) >= - (primitive_util::UnderflowExponent(to_type) - - primitive_util::SignificandWidth(to_type)) && - primitive_util::OverflowExponent(from_type) <= - primitive_util::OverflowExponent(to_type); + return + // Target mantissa should be large enough. + primitive_util::SignificandWidth(from_type) <= + primitive_util::SignificandWidth(to_type) && + // Target exponent should be large enough. + primitive_util::ExponentWidth(from_type) <= + primitive_util::ExponentWidth(to_type) && + // HasInfinity check. + (!primitive_util::HasInfinity(from_type) || + primitive_util::HasInfinity(to_type)) && + // HasNaN check. + (!primitive_util::HasNaN(from_type) || + primitive_util::HasNaN(to_type)) && + // HasNegativeZero check. + (!primitive_util::HasNegativeZero(from_type) || + primitive_util::HasNegativeZero(to_type)) && + // Minimum denormal should be representable by target type. + (primitive_util::UnderflowExponent(from_type) - + primitive_util::SignificandWidth(from_type)) >= + (primitive_util::UnderflowExponent(to_type) - + primitive_util::SignificandWidth(to_type)) && + // Maximum exponent may be larger with custom bias (e.g. F8E4M3B11FNUZ). + primitive_util::OverflowExponent(from_type) <= + primitive_util::OverflowExponent(to_type); } // F -> I is not safe because it drops fractional numbers. if (!primitive_util::IsIntegralType(from_type)) { diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index e4abeb4ff7ac9b..d0433f07e5a26c 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -69,8 +69,9 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][F8E4M3] = expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = expecteds[PRED][F8E3M4] = true; + expecteds[PRED][F4E2M1FN] = true; + expecteds[PRED][F8E8M0FNU] = false; expecteds[S1][PRED] = false; - expecteds[S2][PRED] = false; expecteds[S1][S1] = true; expecteds[S1][S2] = true; expecteds[S1][S4] = true; @@ -91,6 +92,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][C64] = true; expecteds[S1][BF16] = true; expecteds[S1][C128] = true; + expecteds[S1][F4E2M1FN] = true; expecteds[S1][F8E5M2] = true; expecteds[S1][F8E4M3] = true; expecteds[S1][F8E4M3FN] = true; @@ -98,8 +100,11 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S1][F8E5M2FNUZ] = true; expecteds[S1][F8E4M3FNUZ] = true; expecteds[S1][F8E3M4] = true; + expecteds[S1][F8E8M0FNU] = false; + expecteds[S2][PRED] = false; expecteds[S2][S1] = false; - expecteds[S2][S2] = expecteds[S2][S4] = true; + expecteds[S2][S2] = true; + expecteds[S2][S4] = true; expecteds[S2][S8] = true; expecteds[S2][S16] = true; expecteds[S2][S32] = true; @@ -117,6 +122,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][C64] = true; expecteds[S2][BF16] = true; expecteds[S2][C128] = true; + expecteds[S2][F4E2M1FN] = true; expecteds[S2][F8E5M2] = true; expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; @@ -124,6 +130,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; expecteds[S2][F8E3M4] = true; + expecteds[S2][F8E8M0FNU] = false; expecteds[S4][PRED] = false; expecteds[S4][S1] = false; expecteds[S4][S2] = false; @@ -145,6 +152,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][C64] = true; expecteds[S4][BF16] = true; expecteds[S4][C128] = true; + expecteds[S4][F4E2M1FN] = false; expecteds[S4][F8E5M2] = true; expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; @@ -152,6 +160,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; expecteds[S4][F8E3M4] = true; + expecteds[S4][F8E8M0FNU] = false; expecteds[S8][PRED] = false; expecteds[S8][S1] = false; expecteds[S8][S2] = false; @@ -173,6 +182,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][C64] = true; expecteds[S8][BF16] = true; expecteds[S8][C128] = true; + expecteds[S8][F4E2M1FN] = false; expecteds[S8][F8E5M2] = false; expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; @@ -180,6 +190,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; expecteds[S8][F8E3M4] = false; + expecteds[S8][F8E8M0FNU] = false; expecteds[S16][PRED] = false; expecteds[S16][S1] = false; expecteds[S16][S2] = false; @@ -201,6 +212,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][C64] = true; expecteds[S16][BF16] = false; expecteds[S16][C128] = true; + expecteds[S16][F4E2M1FN] = false; expecteds[S16][F8E5M2] = false; expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; @@ -208,6 +220,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; expecteds[S16][F8E3M4] = false; + expecteds[S16][F8E8M0FNU] = false; expecteds[S32][PRED] = false; expecteds[S32][S1] = false; expecteds[S32][S2] = false; @@ -229,6 +242,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][C64] = false; expecteds[S32][BF16] = false; expecteds[S32][C128] = true; + expecteds[S32][F4E2M1FN] = false; expecteds[S32][F8E5M2] = false; expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; @@ -236,6 +250,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; expecteds[S32][F8E3M4] = false; + expecteds[S32][F8E8M0FNU] = false; expecteds[S64][PRED] = false; expecteds[S64][S1] = false; expecteds[S64][S2] = false; @@ -257,6 +272,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][C64] = false; expecteds[S64][BF16] = false; expecteds[S64][C128] = false; + expecteds[S64][F4E2M1FN] = false; expecteds[S64][F8E5M2] = false; expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; @@ -264,6 +280,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; expecteds[S64][F8E3M4] = false; + expecteds[S64][F8E8M0FNU] = false; expecteds[U1][PRED] = false; expecteds[U1][S1] = false; expecteds[U1][S2] = true; @@ -285,8 +302,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][C64] = true; expecteds[U1][BF16] = true; expecteds[U1][C128] = true; - expecteds[U1][BF16] = true; - expecteds[U1][C128] = true; + expecteds[U1][F4E2M1FN] = true; expecteds[U1][F8E5M2] = true; expecteds[U1][F8E4M3] = true; expecteds[U1][F8E4M3FN] = true; @@ -294,14 +310,16 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U1][F8E5M2FNUZ] = true; expecteds[U1][F8E4M3FNUZ] = true; expecteds[U1][F8E3M4] = true; + expecteds[U1][F8E8M0FNU] = false; expecteds[U2][PRED] = false; - expecteds[U2][U1] = expecteds[U2][S1] = false; + expecteds[U2][S1] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; expecteds[U2][S8] = true; expecteds[U2][S16] = true; expecteds[U2][S32] = true; expecteds[U2][S64] = true; + expecteds[U2][U1] = false; expecteds[U2][U2] = true; expecteds[U2][U4] = true; expecteds[U2][U8] = true; @@ -314,8 +332,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][C64] = true; expecteds[U2][BF16] = true; expecteds[U2][C128] = true; - expecteds[U2][BF16] = true; - expecteds[U2][C128] = true; + expecteds[U2][F4E2M1FN] = true; expecteds[U2][F8E5M2] = true; expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; @@ -323,6 +340,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; expecteds[U2][F8E3M4] = true; + expecteds[U2][F8E8M0FNU] = false; expecteds[U4][PRED] = false; expecteds[U4][S1] = false; expecteds[U4][S2] = false; @@ -344,8 +362,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][C64] = true; expecteds[U4][BF16] = true; expecteds[U4][C128] = true; - expecteds[U4][BF16] = true; - expecteds[U4][C128] = true; + expecteds[U4][F4E2M1FN] = false; expecteds[U4][F8E5M2] = false; expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; @@ -353,6 +370,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; expecteds[U4][F8E3M4] = true; + expecteds[U4][F8E8M0FNU] = false; expecteds[U8][PRED] = false; expecteds[U8][S1] = false; expecteds[U8][S2] = false; @@ -374,8 +392,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][C64] = true; expecteds[U8][BF16] = true; expecteds[U8][C128] = true; - expecteds[U8][BF16] = true; - expecteds[U8][C128] = true; + expecteds[U8][F4E2M1FN] = false; expecteds[U8][F8E5M2] = false; expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; @@ -383,6 +400,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; expecteds[U8][F8E3M4] = false; + expecteds[U8][F8E8M0FNU] = false; expecteds[U16][PRED] = false; expecteds[U16][S1] = false; expecteds[U16][S2] = false; @@ -404,6 +422,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][C64] = true; expecteds[U16][BF16] = false; expecteds[U16][C128] = true; + expecteds[U16][F4E2M1FN] = false; expecteds[U16][F8E5M2] = false; expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; @@ -411,6 +430,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; expecteds[U16][F8E3M4] = false; + expecteds[U16][F8E8M0FNU] = false; expecteds[U32][PRED] = false; expecteds[U32][S1] = false; expecteds[U32][S2] = false; @@ -432,6 +452,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][C64] = false; expecteds[U32][BF16] = false; expecteds[U32][C128] = true; + expecteds[U32][F4E2M1FN] = false; expecteds[U32][F8E5M2] = false; expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; @@ -439,6 +460,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; expecteds[U32][F8E3M4] = false; + expecteds[U32][F8E8M0FNU] = false; expecteds[U64][PRED] = false; expecteds[U64][S1] = false; expecteds[U64][S2] = false; @@ -460,6 +482,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][C64] = false; expecteds[U64][BF16] = false; expecteds[U64][C128] = false; + expecteds[U64][F4E2M1FN] = false; expecteds[U64][F8E5M2] = false; expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; @@ -467,6 +490,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; expecteds[U64][F8E3M4] = false; + expecteds[U64][F8E8M0FNU] = false; expecteds[F16][PRED] = false; expecteds[F16][S1] = false; expecteds[F16][S2] = false; @@ -488,6 +512,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][C64] = true; expecteds[F16][BF16] = false; expecteds[F16][C128] = true; + expecteds[F16][F4E2M1FN] = false; expecteds[F16][F8E5M2] = false; expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; @@ -495,6 +520,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; expecteds[F16][F8E3M4] = false; + expecteds[F16][F8E8M0FNU] = false; expecteds[F32][PRED] = false; expecteds[F32][S1] = false; expecteds[F32][S2] = false; @@ -516,6 +542,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][C64] = true; expecteds[F32][BF16] = false; expecteds[F32][C128] = true; + expecteds[F32][F4E2M1FN] = false; expecteds[F32][F8E5M2] = false; expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; @@ -523,6 +550,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; expecteds[F32][F8E3M4] = false; + expecteds[F32][F8E8M0FNU] = false; expecteds[F64][PRED] = false; expecteds[F64][S1] = false; expecteds[F64][S2] = false; @@ -544,6 +572,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][C64] = false; expecteds[F64][BF16] = false; expecteds[F64][C128] = true; + expecteds[F64][F4E2M1FN] = false; expecteds[F64][F8E5M2] = false; expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; @@ -551,6 +580,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; expecteds[F64][F8E3M4] = false; + expecteds[F64][F8E8M0FNU] = false; expecteds[C64][PRED] = false; expecteds[C64][S1] = false; expecteds[C64][S2] = false; @@ -572,6 +602,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][C64] = true; expecteds[C64][BF16] = false; expecteds[C64][C128] = true; + expecteds[C64][F4E2M1FN] = false; expecteds[C64][F8E5M2] = false; expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; @@ -579,6 +610,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; expecteds[C64][F8E3M4] = false; + expecteds[C64][F8E8M0FNU] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S1] = false; expecteds[BF16][S2] = false; @@ -600,6 +632,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][C64] = true; expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; + expecteds[BF16][F4E2M1FN] = false; expecteds[BF16][F8E5M2] = false; expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; @@ -607,6 +640,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; expecteds[BF16][F8E3M4] = false; + expecteds[BF16][F8E8M0FNU] = false; expecteds[C128][PRED] = false; expecteds[C128][S1] = false; expecteds[C128][S2] = false; @@ -628,6 +662,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][C64] = false; expecteds[C128][BF16] = false; expecteds[C128][C128] = true; + expecteds[C128][F4E2M1FN] = false; expecteds[C128][F8E5M2] = false; expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; @@ -635,6 +670,37 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; expecteds[C128][F8E3M4] = false; + expecteds[C128][F8E8M0FNU] = false; + expecteds[F4E2M1FN][PRED] = false; + expecteds[F4E2M1FN][S1] = false; + expecteds[F4E2M1FN][S2] = false; + expecteds[F4E2M1FN][S4] = false; + expecteds[F4E2M1FN][S8] = false; + expecteds[F4E2M1FN][S16] = false; + expecteds[F4E2M1FN][S32] = false; + expecteds[F4E2M1FN][S64] = false; + expecteds[F4E2M1FN][U1] = false; + expecteds[F4E2M1FN][U2] = false; + expecteds[F4E2M1FN][U4] = false; + expecteds[F4E2M1FN][U8] = false; + expecteds[F4E2M1FN][U16] = false; + expecteds[F4E2M1FN][U32] = false; + expecteds[F4E2M1FN][U64] = false; + expecteds[F4E2M1FN][F16] = true; + expecteds[F4E2M1FN][F32] = true; + expecteds[F4E2M1FN][F64] = true; + expecteds[F4E2M1FN][C64] = true; + expecteds[F4E2M1FN][BF16] = true; + expecteds[F4E2M1FN][C128] = true; + expecteds[F4E2M1FN][F4E2M1FN] = true; + expecteds[F4E2M1FN][F8E5M2] = true; + expecteds[F4E2M1FN][F8E4M3] = true; + expecteds[F4E2M1FN][F8E4M3FN] = true; + expecteds[F4E2M1FN][F8E4M3B11FNUZ] = false; + expecteds[F4E2M1FN][F8E4M3FNUZ] = false; + expecteds[F4E2M1FN][F8E5M2FNUZ] = false; + expecteds[F4E2M1FN][F8E3M4] = true; + expecteds[F4E2M1FN][F8E8M0FNU] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S1] = false; expecteds[F8E5M2][S2] = false; @@ -656,6 +722,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][C64] = true; expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; + expecteds[F8E5M2][F4E2M1FN] = false; expecteds[F8E5M2][F8E5M2] = true; expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; @@ -663,6 +730,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E5M2][F8E8M0FNU] = false; expecteds[F8E4M3][PRED] = false; expecteds[F8E4M3][S1] = false; expecteds[F8E4M3][S2] = false; @@ -684,6 +752,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][C64] = true; expecteds[F8E4M3][BF16] = true; expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F4E2M1FN] = false; expecteds[F8E4M3][F8E5M2] = false; expecteds[F8E4M3][F8E5M2FNUZ] = false; expecteds[F8E4M3][F8E4M3] = true; @@ -691,6 +760,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][F8E4M3FNUZ] = false; expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3][F8E3M4] = false; + expecteds[F8E4M3][F8E8M0FNU] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S1] = false; expecteds[F8E4M3FN][S2] = false; @@ -712,6 +782,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][C64] = true; expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; + expecteds[F8E4M3FN][F4E2M1FN] = false; expecteds[F8E4M3FN][F8E5M2] = false; expecteds[F8E4M3FN][F8E5M2FNUZ] = false; expecteds[F8E4M3FN][F8E4M3] = false; @@ -719,6 +790,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][F8E3M4] = false; + expecteds[F8E4M3FN][F8E8M0FNU] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S1] = false; expecteds[F8E4M3B11FNUZ][S2] = false; @@ -740,6 +812,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][C64] = true; expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; + expecteds[F8E4M3B11FNUZ][F4E2M1FN] = false; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; @@ -747,6 +820,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E3M4] = false; + expecteds[F8E4M3B11FNUZ][F8E8M0FNU] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S1] = false; expecteds[F8E5M2FNUZ][S2] = false; @@ -768,6 +842,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][C64] = true; expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; + expecteds[F8E5M2FNUZ][F4E2M1FN] = false; expecteds[F8E5M2FNUZ][F8E5M2] = false; expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; @@ -775,6 +850,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][F8E3M4] = false; + expecteds[F8E5M2FNUZ][F8E8M0FNU] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S1] = false; expecteds[F8E4M3FNUZ][S2] = false; @@ -796,6 +872,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][C64] = true; expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; + expecteds[F8E4M3FNUZ][F4E2M1FN] = false; expecteds[F8E4M3FNUZ][F8E5M2] = false; expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; @@ -803,6 +880,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E4M3FNUZ][F8E8M0FNU] = false; expecteds[F8E3M4][PRED] = false; expecteds[F8E3M4][S1] = false; expecteds[F8E3M4][S2] = false; @@ -824,6 +902,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][C64] = true; expecteds[F8E3M4][BF16] = true; expecteds[F8E3M4][C128] = true; + expecteds[F8E3M4][F4E2M1FN] = false; expecteds[F8E3M4][F8E5M2] = false; expecteds[F8E3M4][F8E5M2FNUZ] = false; expecteds[F8E3M4][F8E4M3] = false; @@ -831,6 +910,37 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][F8E4M3FNUZ] = false; expecteds[F8E3M4][F8E4M3B11FNUZ] = false; expecteds[F8E3M4][F8E3M4] = true; + expecteds[F8E3M4][F8E8M0FNU] = false; + expecteds[F8E8M0FNU][PRED] = false; + expecteds[F8E8M0FNU][S1] = false; + expecteds[F8E8M0FNU][S2] = false; + expecteds[F8E8M0FNU][S4] = false; + expecteds[F8E8M0FNU][S8] = false; + expecteds[F8E8M0FNU][S16] = false; + expecteds[F8E8M0FNU][S32] = false; + expecteds[F8E8M0FNU][S64] = false; + expecteds[F8E8M0FNU][U1] = false; + expecteds[F8E8M0FNU][U2] = false; + expecteds[F8E8M0FNU][U4] = false; + expecteds[F8E8M0FNU][U8] = false; + expecteds[F8E8M0FNU][U16] = false; + expecteds[F8E8M0FNU][U32] = false; + expecteds[F8E8M0FNU][U64] = false; + expecteds[F8E8M0FNU][F16] = false; + expecteds[F8E8M0FNU][F32] = true; + expecteds[F8E8M0FNU][F64] = true; + expecteds[F8E8M0FNU][C64] = true; + expecteds[F8E8M0FNU][BF16] = true; + expecteds[F8E8M0FNU][C128] = true; + expecteds[F8E8M0FNU][F4E2M1FN] = false; + expecteds[F8E8M0FNU][F8E5M2] = false; + expecteds[F8E8M0FNU][F8E4M3] = false; + expecteds[F8E8M0FNU][F8E4M3FN] = false; + expecteds[F8E8M0FNU][F8E4M3B11FNUZ] = false; + expecteds[F8E8M0FNU][F8E4M3FNUZ] = false; + expecteds[F8E8M0FNU][F8E5M2FNUZ] = false; + expecteds[F8E8M0FNU][F8E3M4] = false; + expecteds[F8E8M0FNU][F8E8M0FNU] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { @@ -851,7 +961,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { << primitive_util::LowercasePrimitiveTypeName(to_type); } } -} +} // NOLINT(readability/fn_size) } // namespace } // namespace xla diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index ed68a1d11403c2..58289385fd7d23 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -32,6 +32,7 @@ std::optional DType::byte_size() const { case kU2: case kS4: case kU4: + case kF4E2M1FN: // Smaller than a byte. return std::nullopt; case kPred: @@ -39,6 +40,7 @@ std::optional DType::byte_size() const { case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -77,12 +79,14 @@ std::optional DType::bit_size() const { return 2; case kS4: case kU4: + case kF4E2M1FN: return 4; case kPred: case kS8: case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -141,9 +145,11 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); + CASE(F4E2M1FN); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // CASE(F8E3M4); // CASE(F8E4M3); + CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -189,9 +195,11 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); + CASE(F4E2M1FN); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // CASE(F8E3M4); // CASE(F8E4M3); + CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index d23efc55a1aa12..864cdd1c063ae4 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -88,8 +88,12 @@ class DType { kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, + kF8E8M0FNU = 33, - // Next = 30 + // MX floating point types. + kF4E2M1FN = 32, + + // Next = 34 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index 3a2b0df7976d6e..2cf453f26c291d 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -70,12 +70,18 @@ message DTypeProto { KIND_F8E4M3FNUZ = 25; KIND_F8E5M2 = 19; KIND_F8E5M2FNUZ = 24; + KIND_F8E8M0FNU = 31; + + // MX floating point types. + KIND_F4E2M1FN = 30; // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind // needs to match xla.PrimitiveType enum, so choose a large enum to avoid // collision. KIND_STRING = 99; + + // Next: 32 } // LINT.ThenChange() Kind kind = 1; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 57fec6702d277d..9d3d3105f54e54 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -42,34 +42,21 @@ TEST(DTypeTest, FromToFromProto) { TEST(DTypeTest, ByteSize) { for (const auto& [kind, byte_size] : std::vector>({ - {DType::kS2, -1}, - {DType::kU2, -1}, - {DType::kS4, -1}, - {DType::kU4, -1}, - {DType::kPred, 1}, - {DType::kS8, 1}, - {DType::kU8, 1}, - {DType::kF8E3M4, 1}, - {DType::kF8E4M3, 1}, - {DType::kF8E4M3FN, 1}, - {DType::kF8E4M3B11FNUZ, 1}, - {DType::kF8E4M3FNUZ, 1}, - {DType::kF8E5M2, 1}, - {DType::kF8E5M2FNUZ, 1}, - {DType::kS16, 2}, - {DType::kU16, 2}, - {DType::kF16, 2}, - {DType::kBF16, 2}, - {DType::kS32, 4}, - {DType::kU32, 4}, - {DType::kF32, 4}, - {DType::kS64, 8}, - {DType::kU64, 8}, - {DType::kF64, 8}, - {DType::kC64, 8}, - {DType::kC128, 16}, - {DType::kToken, -1}, - {DType::kInvalid, -1}, + {DType::kS2, -1}, {DType::kU2, -1}, + {DType::kS4, -1}, {DType::kU4, -1}, + {DType::kPred, 1}, {DType::kS8, 1}, + {DType::kU8, 1}, {DType::kF4E2M1FN, -1}, + {DType::kF8E3M4, 1}, {DType::kF8E4M3, 1}, + {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, + {DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1}, + {DType::kF8E5M2FNUZ, 1}, {DType::kF8E8M0FNU, 1}, + {DType::kS16, 2}, {DType::kU16, 2}, + {DType::kF16, 2}, {DType::kBF16, 2}, + {DType::kS32, 4}, {DType::kU32, 4}, + {DType::kF32, 4}, {DType::kS64, 8}, + {DType::kU64, 8}, {DType::kF64, 8}, + {DType::kC64, 8}, {DType::kC128, 16}, + {DType::kToken, -1}, {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).byte_size(), @@ -80,34 +67,21 @@ TEST(DTypeTest, ByteSize) { TEST(DTypeTest, BitSize) { for (const auto& [kind, bit_size] : std::vector>({ - {DType::kS2, 2}, - {DType::kU2, 2}, - {DType::kS4, 4}, - {DType::kU4, 4}, - {DType::kPred, 8}, - {DType::kS8, 8}, - {DType::kU8, 8}, - {DType::kF8E3M4, 8}, - {DType::kF8E4M3, 8}, - {DType::kF8E4M3FN, 8}, - {DType::kF8E4M3B11FNUZ, 8}, - {DType::kF8E4M3FNUZ, 8}, - {DType::kF8E5M2, 8}, - {DType::kF8E5M2FNUZ, 8}, - {DType::kS16, 16}, - {DType::kU16, 16}, - {DType::kF16, 16}, - {DType::kBF16, 16}, - {DType::kS32, 32}, - {DType::kU32, 32}, - {DType::kF32, 32}, - {DType::kS64, 64}, - {DType::kU64, 64}, - {DType::kF64, 64}, - {DType::kC64, 64}, - {DType::kC128, 128}, - {DType::kToken, -1}, - {DType::kInvalid, -1}, + {DType::kS2, 2}, {DType::kU2, 2}, + {DType::kS4, 4}, {DType::kU4, 4}, + {DType::kPred, 8}, {DType::kS8, 8}, + {DType::kU8, 8}, {DType::kF4E2M1FN, 4}, + {DType::kF8E3M4, 8}, {DType::kF8E4M3, 8}, + {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, + {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8}, + {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 8}, + {DType::kS16, 16}, {DType::kU16, 16}, + {DType::kF16, 16}, {DType::kBF16, 16}, + {DType::kS32, 32}, {DType::kU32, 32}, + {DType::kF32, 32}, {DType::kS64, 64}, + {DType::kU64, 64}, {DType::kF64, 64}, + {DType::kC64, 64}, {DType::kC128, 128}, + {DType::kToken, -1}, {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).bit_size(), diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index 9c581ec6227cae..2af3281a588cce 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -44,6 +44,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF4E2M1FN, xla::PrimitiveType::F4E2M1FN); CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); @@ -51,6 +52,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); + CASE(DType::kF8E8M0FNU, xla::PrimitiveType::F8E8M0FNU); CASE(DType::kF16, xla::PrimitiveType::F16); CASE(DType::kF32, xla::PrimitiveType::F32); CASE(DType::kBF16, xla::PrimitiveType::BF16); @@ -83,6 +85,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: + case xla::PrimitiveType::F4E2M1FN: case xla::PrimitiveType::F8E3M4: case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: @@ -90,6 +93,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::F8E4M3FNUZ: case xla::PrimitiveType::F8E5M2: case xla::PrimitiveType::F8E5M2FNUZ: + case xla::PrimitiveType::F8E8M0FNU: case xla::PrimitiveType::F16: case xla::PrimitiveType::F32: case xla::PrimitiveType::BF16: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 631b0bcb9b9562..45baa4abf79351 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -184,6 +184,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E3M4; @@ -205,6 +208,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; } else if (std::is_same() || !options.squash_64bit_types) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); type = primitive_util::NativeToPrimitiveType(); @@ -398,6 +404,10 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } if (dtypes.np_float8_e3m4.has_value()) { (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; @@ -415,6 +425,10 @@ absl::StatusOr DevicePut(nb::handle arg, HandleNumpyScalar; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; @@ -595,8 +609,10 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 50366be350bc08..473c082e1425cc 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -58,6 +58,7 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + std::optional float4_e2m1fn; std::optional float8_e3m4; std::optional float8_e4m3; nb_dtype float8_e4m3fn; @@ -65,6 +66,7 @@ struct CustomDtypes { nb_dtype float8_e4m3fnuz; nb_dtype float8_e5m2; nb_dtype float8_e5m2fnuz; + std::optional float8_e8m0fnu; std::optional int2; nb_dtype int4; std::optional uint2; @@ -76,6 +78,10 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { + dtypes->float4_e2m1fn = + nb_dtype::from_args(ml_dtypes.attr("float4_e2m1fn")); + } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); } @@ -91,6 +97,10 @@ const CustomDtypes& GetCustomDtypes() { nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->float8_e5m2fnuz = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->float8_e8m0fnu = + nb_dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4")); dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4")); if (nb::hasattr(ml_dtypes, "int2")) { @@ -147,6 +157,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + if (custom_dtypes.float4_e2m1fn.has_value()) { + map->emplace(*custom_dtypes.float4_e2m1fn, F4E2M1FN); + } if (custom_dtypes.float8_e3m4.has_value()) { map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); } @@ -158,6 +171,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); map->emplace(custom_dtypes.float8_e5m2, F8E5M2); map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ); + if (custom_dtypes.float8_e8m0fnu.has_value()) { + map->emplace(*custom_dtypes.float8_e8m0fnu, F8E8M0FNU); + } if (custom_dtypes.int2.has_value()) { map->emplace(*custom_dtypes.int2, S2); } @@ -217,6 +233,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F4E2M1FN: + if (custom_dtypes.float4_e2m1fn.has_value()) { + return *custom_dtypes.float4_e2m1fn; + } + break; case F8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -237,6 +258,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return custom_dtypes.float8_e5m2; case F8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case F8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case BF16: return custom_dtypes.bfloat16; case F16: @@ -307,6 +333,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF4E2M1FN: + if (custom_dtypes.float4_e2m1fn.has_value()) { + return *custom_dtypes.float4_e2m1fn; + } + break; case ifrt::DType::kF8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -327,6 +358,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return custom_dtypes.float8_e5m2; case ifrt::DType::kF8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case ifrt::DType::kF8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case ifrt::DType::kString: // PEP 3118 code for "pointer to Python Object". We use Python objects // instead of 'U' (Unicode string) or 'V' (raw data) because the latter @@ -380,6 +416,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { + dtypes->np_float4_e2m1fn = nb::object(ml_dtypes.attr("float4_e2m1fn")); + } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); } @@ -392,6 +431,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_float8_e5m2 = nb::object(ml_dtypes.attr("float8_e5m2")); dtypes->np_float8_e4m3fnuz = nb::object(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->np_float8_e5m2fnuz = nb::object(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->np_float8_e8m0fnu = nb::object(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->np_float16 = nb::object(numpy.attr("float16")); dtypes->np_float32 = nb::object(numpy.attr("float32")); dtypes->np_float64 = nb::object(numpy.attr("float64")); diff --git a/xla/python/types.h b/xla/python/types.h index aacfea1a17997f..babdf5a9bd4167 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -81,6 +81,7 @@ struct NumpyScalarTypes { nanobind::object np_uint64; nanobind::object np_bfloat16; // Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0. + std::optional np_float4_e2m1fn; std::optional np_float8_e3m4; std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; @@ -88,6 +89,7 @@ struct NumpyScalarTypes { nanobind::object np_float8_e4m3fnuz; nanobind::object np_float8_e5m2; nanobind::object np_float8_e5m2fnuz; + std::optional np_float8_e8m0fnu; nanobind::object np_float16; nanobind::object np_float32; nanobind::object np_float64; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 0085e3224efe20..9a9e019c3d3298 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -204,9 +204,11 @@ NB_MODULE(xla_extension, m) { .value("U32", U32) .value("U64", U64) .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. // .value("F8E3M4", F8E3M4) // .value("F8E4M3", F8E4M3) + .value("F8E8M0FNU", F8E8M0FNU) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 46dd4a72edd1e7..0781c88b5d8f48 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -282,8 +282,12 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): bfloat16 = ml_dtypes.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# Also, it would be better to conditionally import these based on whether they +# are in the current version of ml_dtypes. +# float4_e2m1fn = ml_dtypes.float4_e2m1fn # float8_e3m4 = ml_dtypes.float8_e3m4 # float8_e4m3 = ml_dtypes.float8_e4m3 +# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -303,8 +307,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), + # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index efc3d2573b2224..cf3370fb5ddac5 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -62,8 +62,10 @@ mlir_api_version: int bfloat16: type[numpy.generic] # TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn: type[numpy.generic] # float8_e3m4: type[numpy.generic] # float8_e4m3: type[numpy.generic] +# float8_e8m0fnu: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index f0cecc9903295e..bd45255f77bb57 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -55,8 +55,10 @@ bfloat16 = xla_client.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn = xla_client.float4_e2m1fn # float8_e3m4 = xla_client.float8_e3m4 # float8_e4m3 = xla_client.float8_e4m3 +# float8_e8m0fnu = xla_client.float8_e8m0fnu float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -189,7 +191,7 @@ def TestFactory(xla_backend, fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] standard_dtypes += fp8_dtypes # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float8_e3m4, float8_e4m3] + # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index ec3ff508a21cb9..cb7e3c90ab711b 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -74,6 +74,7 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F4E2M1FN: PrimitiveType F8E3M4: PrimitiveType F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType @@ -81,6 +82,7 @@ class PrimitiveType(enum.IntEnum): F8E4M3FNUZ: PrimitiveType F8E5M2: PrimitiveType F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType BF16: PrimitiveType F16: PrimitiveType F32: PrimitiveType diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 5c28de6021def4..d6d08d5926e343 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -602,6 +602,10 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&s4_support); FloatSupport u4_support(U4, U8); pipeline.AddPass(&u4_support); + FloatSupport f4e2m1fn_support(F4E2M1FN, F16); + pipeline.AddPass(&f4e2m1fn_support); + FloatSupport f8e8m0fnu_support(F8E8M0FNU, F32); + pipeline.AddPass(&f8e8m0fnu_support); // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index 18841d2712dcbc..90c4f6c82e4082 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ, F8E4M3, F8E3M4 + // F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN, F8E8M0FNU default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 083f07b8bc8fc3..83756d35eb4e3d 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -864,6 +864,223 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, return f16_value; } +absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, + llvm::IRBuilderBase* b) { + auto i8_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt8Ty(), val); + }; + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // Cast the input value to an integer for bitwise manipulation. + // Get the absolute value of the input (discard the sign). + // f16_bits = bitcast(f16_value, int) + // f16_abs_bits = f16_bits & 0x7FFF + llvm::Value* f16_bits = b->CreateBitCast(f16_value, b->getInt16Ty()); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_bits, i16_const(0x7FFF)); + + // If the input absolute value is >= 7.0 or an infinity, the result saturates + // to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0. + // If (0 <= input <= 0.25), the result is rounded to 0.0. + // If the input is NaN, the result is undefined (implemented as minus zero). + // The rest of the cases are handled by the "happy path". + // is_overflow = f16_abs_bits >= 0x1.Cp2 + // is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows) + // is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows) + // is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold) + llvm::Value* is_overflow = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x4700)); // 7.0 + llvm::Value* is_one = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x3A00)); // 0.75 + llvm::Value* is_zero = + b->CreateICmpULE(f16_abs_bits, i16_const(0x3400)); // 0.25 + llvm::Value* is_nan = + b->CreateICmpUGT(f16_abs_bits, i16_const(0x7C00)); // inf + + // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as + // the type doesn't have Inf/NaN and can represent unbiased exponent 2). + // This case, as well as the denormal, is handled below. + TF_ASSIGN_OR_RETURN( + llvm::Value * reduced_precision, + EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1, + /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1, + /*quiet_nans=*/false, b)); + + // Cast the reduced precision value to an integer for bitwise manipulation. + // Discard the least significant (9) mantissa bits leaving 1 bit. + // Truncate to + // as_int16 = bitcast(reduced_precision, int) + // as_int8 = as_int16 >> (f16_mantissa - f4_mantissa) + llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty()); + llvm::Value* as_int8 = + b->CreateTrunc(b->CreateLShr(as_int16, mantissa_diff), b->getInt8Ty()); + + // Get the sign (0 or 1). + // f4_sign = as_int8 >> 6 + llvm::Value* f4_sign = b->CreateLShr(as_int8, 6); + + // Get exponent and mantissa bits without the sign. + // Important: the mask is 0x3F (not 0x7F), discard bit #6. + // f4_bits = as_int8 & 0x3F + llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F)); + + // Convert F16 exponent to F4 exponent by readjusting the exponent bias. + // This produces the "normal" result, i.e. not Inf or NaN or denormal. + // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) + constexpr int f4_exponent_offset = bias_diff << 1; + llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(f4_exponent_offset)); + + // If the rounding resulted in zero exponent, the value is incorrect. + // This happens when the input is < 1.0 + // is_underflow = f4_normal <= 1 + llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1)); + + // Chain of selects that handles the special cases. + // f4_result = + // is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) : + // is_overflow ? (is_nan ? -0.0 : 6.0) : + // f4_normal + llvm::Value* f4_result = b->CreateSelect( + is_underflow, + // If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0 + b->CreateSelect(is_one, i8_const(0x2), + b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1))), + // If overflow, the input is >= 7.0 or infinity or NaN. + b->CreateSelect(is_overflow, + b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)), + f4_normal)); + + // Add sign to the resulting value. + // f4_signed_result = (f4_sign << 3) | f4_result + return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3)); +} + +llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilderBase* b) { + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // The input value is a 8-bit integer, extend it to 16-bit integer. + // as_int16 = bitcast(f8_value, int) + llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); + + // Get the sign and shift it to F16 position. + // f4_sign = as_int16 >> 3 + // f16_sign_bit = f4_sign << 15 + llvm::Value* f4_sign = b->CreateLShr(as_int16, 3); + llvm::Value* f16_sign_bit = b->CreateShl(f4_sign, 15); + + // Get exponent and mantissa bits without the sign. + // f4_bits = as_int16 & 0x7 + // f16_bits = f4_bits << (f16_mantissa - f4_mantissa) + llvm::Value* f4_bits = b->CreateAnd(as_int16, i16_const(0x7)); + llvm::Value* f16_bits = b->CreateShl(f4_bits, mantissa_diff); + + // Convert F16 exponent to F4 exponent by readjusting the exponent bias. + // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) + constexpr int f16_exponent_offset = bias_diff << 10; + llvm::Value* f16_normal = + b->CreateAdd(f16_bits, i16_const(f16_exponent_offset)); + + // For denormal and zero, the exponent is different. Handle these cases + // separately below. + // is_denorm_or_zero = f4_bits <= 1 + // is_zero = f4_bits == 0 + llvm::Value* is_denorm_or_zero = b->CreateICmpULE(f4_bits, i16_const(1)); + llvm::Value* is_zero = b->CreateICmpEQ(f4_bits, i16_const(0)); + + // Chain of selects that handles the special cases. + // f16_result = is_denorm_or_zero ? (is_zero ? 0.0 : 0.5) : f16_normal + llvm::Value* f16_result = b->CreateSelect( + is_denorm_or_zero, + b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)), + f16_normal); + + // Add sign to the resulting value. + // f16_signed_result = f16_sign_bit | f16_result + llvm::Value* f16_signed_result = b->CreateOr(f16_result, f16_sign_bit); + return b->CreateBitCast(f16_signed_result, b->getHalfTy()); +} + +llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, + llvm::IRBuilderBase* b) { + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + + // Cast the input value to an integer for bitwise manipulation. + // as_int32 = bitcast(f32_value, int) + llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); + + // Check if the input is zero, negative, overflow, infinity or NaN. + // All of these cases cannot be represented in the E8M0 format. + // is_zero_or_negative = as_int32 <= 0 + // is_overflow_or_nan = as_int32 >= 0x1.8p127 + // is_nan = is_zero_or_negative | is_overflow_or_nan + llvm::Value* is_zero_or_negative = b->CreateICmpSLE(as_int32, i32_const(0)); + llvm::Value* is_overflow_or_nan = + b->CreateICmpSGE(as_int32, i32_const(0x7F400000)); // 1.5 * 2^127 + llvm::Value* is_nan = b->CreateOr(is_zero_or_negative, is_overflow_or_nan); + + // Check if the input is a denormal which should round to the minimum value + // (2^-127), as there is no zero value. + // is_denorm = as_int32 <= 0x1p-127 + llvm::Value* is_denorm = + b->CreateICmpULE(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + + // Round the value (always up) and discard the mantissa. + // rounded = as_int32 + 0x1p-127 + // f8_normal = as_int32 >> f32_mantissa + llvm::Value* rounded = + b->CreateAdd(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + llvm::Value* f8_normal = b->CreateAShr(rounded, 23); + + // Chain of selects that handles the special cases. + // f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal) + llvm::Value* f8_result = + b->CreateSelect(is_nan, i32_const(0xFF), + b->CreateSelect(is_denorm, i32_const(0x00), f8_normal)); + + // Truncate to the result type. + return b->CreateTrunc(f8_result, b->getInt8Ty()); +} + +llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilderBase* b) { + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + + // The input value is a 8-bit integer, extend it to 32-bit integer. + // as_int32 = bitcast(f8_value, int) + llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); + + // Check if the input is a denormal or NaN. + // is_zero = as_int32 == 0x00 + // is_nan = as_int32 == 0xFF + llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0)); + llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF)); + + // Shift exponent to the left for the normal case. + // f32_normal = as_int32 << mantissa_diff + llvm::Value* f32_normal = b->CreateShl(as_int32, 23); + + // Chain of selects that handles the special cases. + // f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal) + llvm::Value* f32_result = b->CreateSelect( + is_nan, i32_const(0x7FC00000), + b->CreateSelect(is_zero, i32_const(0x400000), f32_normal)); + + // Bitcast integer bits to the result type. + return b->CreateBitCast(f32_result, b->getFloatTy()); +} + llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, @@ -958,6 +1175,18 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F4E2M1FN) { + return EmitF16ToF4e2m1fn( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } + if (to_type == F8E8M0FNU) { + return EmitF32ToF8e8m0fnu( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + b_), + b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz( F16, @@ -1163,10 +1392,29 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F4E2M1FN) { + TF_RET_CHECK(to_type != F4E2M1FN); + operand_value = EmitF4e2m1fnToF16(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F8E8M0FNU) { + TF_RET_CHECK(to_type != F8E8M0FNU); + operand_value = EmitF8e8m0fnuToF32(operand_value, b_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { TF_RET_CHECK(to_type != from_type); PrimitiveType cast_type = primitive_util::IsFloatingPointType(to_type) ? to_type : F16; + if (to_type == F8E8M0FNU || to_type == F4E2M1FN) { + cast_type = F32; + } TF_ASSIGN_OR_RETURN(operand_value, EmitF8fnuzToFloating(from_type, operand_value, cast_type, b_, module_)); @@ -1249,6 +1497,24 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } + if (to_type == F4E2M1FN) { + // Cast to F16 first. Casts to F4E2M1FN must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); + } + return EmitF16ToF4e2m1fn(operand_value, b_); + } + if (to_type == F8E8M0FNU) { + // Cast to F32 first. Casts to F8E8M0FNU must be from F32. + if (from_type != F32) { + operand_value = b_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(F32, module_->getContext())); + } + return EmitF32ToF8e8m0fnu(operand_value, b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } @@ -1809,6 +2075,12 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); + } else if (operand_type == F4E2M1FN) { + lhs_value = EmitF4e2m1fnToF16(lhs_value, b_); + rhs_value = EmitF4e2m1fnToF16(rhs_value, b_); + } else if (operand_type == F8E8M0FNU) { + lhs_value = EmitF8e8m0fnuToF32(lhs_value, b_); + rhs_value = EmitF8e8m0fnuToF32(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { TF_ASSIGN_OR_RETURN( lhs_value, @@ -3663,10 +3935,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; - if (component_element_type == F8E4M3FNUZ) { - float_ir_type = - llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); - } else if (component_element_type == F8E5M2FNUZ) { + if (component_element_type == F8E4M3FNUZ || + component_element_type == F8E5M2FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); } else { diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index f947aa8ada14c0..0d906f47b4c474 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -100,9 +100,10 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -615,7 +616,9 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -630,6 +633,10 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { auto tname = this->TypeName(); + if (std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } const auto hlo_text = absl::StrReplaceAll(R"( HloModule matmul diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index 4afb96362cf86e..e0be95da5f6680 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -40,6 +40,8 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F4E2M1FN: + return &llvm::APFloat::Float4E2M1FN(); case F8E3M4: return &llvm::APFloat::Float8E3M4(); case F8E4M3: @@ -54,6 +56,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( return &llvm::APFloat::Float8E5M2(); case F8E5M2FNUZ: return &llvm::APFloat::Float8E5M2FNUZ(); + case F8E8M0FNU: + return &llvm::APFloat::Float8E8M0FNU(); case BF16: return &llvm::APFloat::BFloat(); case F16: @@ -72,6 +76,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, PrimitiveType type) { switch (type) { + case F4E2M1FN: + return b->getIntNTy(4); case F8E3M4: case F8E4M3: case F8E4M3B11FNUZ: @@ -79,6 +85,7 @@ absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, case F8E4M3FNUZ: case F8E5M2: case F8E5M2FNUZ: + case F8E8M0FNU: return b->getInt8Ty(); case BF16: return b->getBFloatTy(); @@ -649,8 +656,14 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign); // Bitwise or the sign bit back in. - sign = b->CreateZExt(sign, output_int_type); - sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type)); + int shift = output_type_bit_width - BitWidth(input_type); + if (shift >= 0) { + sign = b->CreateZExt(sign, output_int_type); + sign = b->CreateShl(sign, shift); + } else { + sign = b->CreateLShr(sign, -shift); + sign = b->CreateTrunc(sign, output_int_type); + } llvm::Value* result = b->CreateOr(sign, result_abs); // Bitcast to the output type. diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc index 0e5a2ffe7a6a60..2fd6b3a74977e6 100644 --- a/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -550,11 +550,19 @@ INSTANTIATE_TEST_SUITE_P( using ReduceTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam; +static absl::string_view init_value(PrimitiveType dtype) { + if (dtype == C64 || dtype == C128) { + return "(0, 0)"; + } else if (dtype == F8E8M0FNU) { + return "1e-40"; + } else { + return "0"; + } +} + TEST_P(ReduceTest, IsTritonSupportedReduction) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( add { Arg_0 = $$0[] parameter(0) Arg_1 = $$0[] parameter(1) @@ -567,7 +575,7 @@ ENTRY triton_computation { ROOT reduce = $$0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=add })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -599,9 +607,7 @@ TEST_P( ReduceTest, UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( add { Arg_0 = $$0[] parameter(0) Arg_1 = $$0[] parameter(1) @@ -614,7 +620,7 @@ ENTRY triton_computation { ROOT reduce = $$0[2] reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -624,9 +630,7 @@ ENTRY triton_computation { TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( add { Arg_0 = $$0[] parameter(0) Arg_1 = $$0[] parameter(1) @@ -638,7 +642,7 @@ ENTRY triton_computation { constant_0 = $$0[] constant($0) ROOT reduce = $$0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -649,9 +653,7 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( add { Arg_0 = $$0[] parameter(0) Arg_1 = $$0[] parameter(1) @@ -670,7 +672,7 @@ ENTRY triton_computation { dimensions={1}, to_apply=add ROOT reduce = $$0[125] get-tuple-element(tuple), index=0 })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -701,9 +703,7 @@ ENTRY triton_computation { TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( custom_call { Arg_0 = $$0[] parameter(0) Arg_1 = $$0[] parameter(1) @@ -716,7 +716,7 @@ ENTRY triton_computation { ROOT reduce = $$0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -740,7 +740,6 @@ using ReductionComputationTest = // computation and in regular HLO. See triton_support.cc for more details. TEST_P(ReductionComputationTest, DifferentBinaryOps) { auto [data_type, opcode, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = absl::Substitute( R"( reduce_computation { @@ -755,7 +754,7 @@ ENTRY triton_computation { ROOT reduce = $$0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=reduce_computation })", - HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0"); + HloOpcodeString(opcode), init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -1115,13 +1114,11 @@ TEST_P(ConstantTest, ConstantEffectiveScalar) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $$0[1,1] constant({{$0}}) })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, @@ -1133,13 +1130,11 @@ TEST_P(ConstantTest, Constant2D) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. auto [data_type, cc] = GetParam(); - bool dtype_is_complex = data_type == C64 || data_type == C128; - const std::string kHloTestTemplate = - absl::Substitute(R"( + const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { ROOT const = $$0[3,3] constant({{$0,$0,$0},{$0,$0,$0},{$0,$0,$0}}) })", - dtype_is_complex ? "(0, 0)" : "0"); + init_value(data_type)); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 2c84b3fbbc0ff5..fa8baed8bdfece 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1484,6 +1484,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); const GpuFloatSupport s4_support(gpu_version, S4, S8); const GpuFloatSupport u4_support(gpu_version, U4, U8); + const GpuFloatSupport f4e2m1fn_support(gpu_version, F4E2M1FN, F16); + const GpuFloatSupport f8e8m0fnu_support(gpu_version, F8E8M0FNU, F32); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); @@ -1497,6 +1499,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&f8e3m4_support); sub_pipeline.AddPass(&s4_support); sub_pipeline.AddPass(&u4_support); + sub_pipeline.AddPass(&f4e2m1fn_support); + sub_pipeline.AddPass(&f8e8m0fnu_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 16383324dfb016..6e0e14e320a7f9 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,9 +29,10 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f8e5m2", "f8e5m2fnuz", "f8e4m3", - "f8e4m3fn", "f8e4m3fnuz", - "f8e4m3b11fnuz", "f8e3m4")); + "f4e2m1fn", "f8e5m2", "f8e5m2fnuz", + "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4", + "f8e8m0fnu")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 9e84f287beb874..4ba73f6042043b 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2993,9 +2993,10 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); if (instruction->opcode() == HloOpcode::kConvert || instruction->opcode() == HloOpcode::kCompare || + instruction->opcode() == HloOpcode::kIsFinite || (instruction->opcode() == HloOpcode::kSelect && operand_shape.element_type() == PRED)) { - // Convert and Compare instructions can change element_size_in_bits + // Some instructions can change element_size_in_bits // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index ff7c4e84a19b00..0edf70328dfad6 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -199,6 +199,8 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case S16: case U16: return llvm::Type::getInt16Ty(context); + case F4E2M1FN: + return llvm::Type::getIntNTy(context, 4); case F8E5M2: case F8E5M2FNUZ: case F8E4M3: @@ -206,6 +208,7 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case F8E4M3B11FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(context); case BF16: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index f5246389e485c3..e3e7d1f17e312f 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,6 +37,10 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. template <> +struct ToDataType { + static constexpr DataType value = DataType::kF4E2M1FN; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E3M4; }; @@ -61,6 +65,10 @@ struct ToDataType { static constexpr DataType value = DataType::kF8E5M2FNUZ; }; template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E8M0FNU; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kFloat; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 6b7a87d80b3aec..24851e56d75eda 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -69,12 +69,14 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6aee86bf2cbc19..182af599af9e5c 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -56,6 +56,10 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { return DataType::kF8E4M3FNUZ; case PrimitiveType::F8E3M4: return DataType::kF8E3M4; + case PrimitiveType::F4E2M1FN: + return DataType::kF4E2M1FN; + case PrimitiveType::F8E8M0FNU: + return DataType::kF8E8M0FNU; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -93,6 +97,10 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { return PrimitiveType::F8E4M3FNUZ; case DataType::kF8E3M4: return PrimitiveType::F8E3M4; + case DataType::kF4E2M1FN: + return PrimitiveType::F4E2M1FN; + case DataType::kF8E8M0FNU: + return PrimitiveType::F8E8M0FNU; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -154,6 +162,8 @@ absl::StatusOr GetBlasComputationType( case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through case PrimitiveType::F8E3M4: // fall-through + case PrimitiveType::F4E2M1FN: // fall-through + case PrimitiveType::F8E8M0FNU: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index e5730121addd8d..8864476bf0d825 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -39,8 +39,10 @@ hipDataType AsHipblasDataType(blas::DataType type) { case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: case blas::DataType::kF8E3M4: - LOG(FATAL) - << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; + case blas::DataType::kF4E2M1FN: + case blas::DataType::kF8E8M0FNU: + LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, " + "F8E3M4, F4E2M1FN and F8E8M0FNU"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index fde51c9d99b16d..867eca75893f95 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -88,6 +88,20 @@ std::pair, std::vector> AllSignedPairs( return {xs, ys}; } +template +void AddNegativeValuesMaybeRemoveZero(std::vector& values) { + values.reserve(values.size() * 2); + if (!has_zero_v) { + values.erase(values.begin()); + } + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto neg = -values[i]; + if (SignAndMagnitude(neg).first) { + values.push_back(neg); + } + } +} + class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: static constexpr float kEpsF32 = std::numeric_limits::epsilon(); @@ -1366,14 +1380,7 @@ class TotalOrderTest : public ClientLibraryTestBase { values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); } #endif - values.reserve(values.size() * 2); - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); std::vector lhs_data; std::vector rhs_data; lhs_data.reserve(values.size() * values.size()); @@ -1418,19 +1425,24 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types; +#if !defined(XLA_TEST_BACKEND_TPU) + // TODO(b/385004399): Run tests on these types on TPU. + tsl::float4_e2m1fn, tsl::float8_e8m0fnu, +#endif + float>; TYPED_TEST_SUITE(TotalOrderTest, Types); @@ -1457,13 +1469,7 @@ TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { if constexpr (std::numeric_limits::has_infinity) { values.push_back(std::numeric_limits::infinity()); } - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); auto lhs = ConstantR1(&builder, values); auto rhs = ConstantR1( &builder, diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 9650077ed57b28..9e191a30b405ae 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -52,7 +52,13 @@ using FloatTypes = ::testing::Types; + tsl::float8_e5m2fnuz +#ifndef XLA_TEST_BACKEND_TPU + // TODO(b/385004399): Run tests on these types on TPU. + , + tsl::float4_e2m1fn, tsl::float8_e8m0fnu +#endif + >; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 4f06ea0cc290c7..a8e370ad50c0d3 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,9 +54,17 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -741,10 +749,11 @@ XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { XlaBuilder builder(this->TestName()); using FP = TypeParam; - auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.5}, FP{2.0}, FP{-0.0}}); ConvertElementType(a, PRED); - std::array expected = {false, true, true, false}; + bool zero_pred = !has_zero_v; + std::array expected = {zero_pred, true, true, zero_pred}; this->template ComputeAndCompareR1(&builder, expected, {}); } @@ -1925,5 +1934,283 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F4E2M1FN + +XLA_TEST_F(ConvertTest, DISABLED_ON_TPU(ConvertF16F4e2m1fnRoundtrip)) { + // Convert from FP16 to FP4, then back to FP16. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFCp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.8p-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.004p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FCp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f4 = + ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, + DISABLED_ON_TPU(DISABLED_ON_CPU(ConvertF32F4e2m1fnRoundtrip))) { + // Convert from FP32 to FP4, then back to FP32. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFFFFEp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.8p-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.000002p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FFFFEp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f4 = ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive)) { + // Convert from FP4 to supported floating point type, then back to FP4. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f4_as_fp = + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f4_as_fp, F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive2)) { + // Convert from supported floating point type to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF4e2m1fnRoundtripExhaustive3)) { + // Convert from FP4 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, + DISABLED_ON_TPU(ConvertF4e2m1fnF16RoundtripExhaustive4)) { + // Convert from (B)F16 to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +// ----- F8E8M0FNU + +XLA_TEST_F(ConvertTest, DISABLED_ON_TPU(ConvertF32F8e8m0fnuRoundtrip)) { + // Convert from FP32 to FP8, then back to FP32. + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, nan}, // No zero values + {-0.0, nan}, + {1.0, 1.0}, + {-1.0, nan}, // No negative values + {nan, nan}, + {inf, nan}, + // clang-format on + {0x1.8p1, 0x1p2}, // Round-to-even up + {0x1.8p2, 0x1p3}, // Round-to-even up (always rounds up) + {0x1p127, 0x1p127}, // Max value + {0x1.7FFFFEp127, 0x1p127}, // Largest number that doesn't overflow + {0x1.8p127, nan}, // Smallest number that overflows + {0x1.FFFFFEp127, nan}, // Overflow + {0x1p-126, 0x1p-126}, // Smallest F8 normal + {0x0.800002p-126, 0x1p-126}, // Smallest number rounding up to normal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E8M0FNU); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive)) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive2)) { + if (this->client_->platform()->Name() == "Host") { + // This test is disabled on CPU, as converting 0x1p-127 from double to float + // using CVTSD2SS on x64 results in an underflow (even though the result is + // representable as denormalized float32). + if (std::is_same_v) { + GTEST_SKIP() << "Skipping test for double precision floating point that " + "loses denormal value during conversion"; + } + } + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, + DISABLED_ON_TPU(ConvertF8e8m0fnuRoundtripExhaustive3)) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, + DISABLED_ON_TPU(ConvertF8e8m0fnuF16RoundtripExhaustive4)) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 7f0d9c4507a2a2..d1d6882b6532a5 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -121,6 +121,7 @@ enum PrimitiveType { F64, C64, C128, + F4E2M1FN, F8E5M2, F8E4M3, F8E4M3FN, @@ -128,17 +129,19 @@ enum PrimitiveType { F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, + F8E8M0FNU, }; const std::vector& primitive_strings() { static auto vec = new std::vector( - {"s1", "s2", "s4", "s8", - "s16", "s32", "s64", "u1", - "u2", "u4", "u8", "u16", - "u32", "u64", "f16", "bf16", - "f32", "f64", "c64", "c128", - "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz", "f8e3m4"}); + {"s1", "s2", "s4", "s8", + "s16", "s32", "s64", "u1", + "u2", "u4", "u8", "u16", + "u32", "u64", "f16", "bf16", + "f32", "f64", "c64", "c128", + "f4e2m1fn", "f8e3m4", "f8e4m3", "f8e4m3b11fnuz", + "f8e4m3fn", "f8e4m3fnuz", "f8e5m2", "f8e5m2fnuz", + "f8e8m0fnu"}); return *vec; } @@ -415,6 +418,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F64: return FillFloatT(buffer, num_elements); + case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -422,6 +426,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: @@ -475,6 +480,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F64: return DisplayT(buffer, num_elements); + case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -482,6 +488,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index 5aabbf28b4baa8..31c8fe56a78faa 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -70,13 +70,15 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 2ac31005c16629..4a6d8fff6f72cd 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -24,6 +24,8 @@ enum DataType { kInt64 = 12; kF8E4M3 = 13; kF8E3M4 = 14; + kF4E2M1FN = 15; + kF8E8M0FNU = 16; } // Describes how a convolution input or output layer's data is formatted. diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index e2c5eb295c6b12..a986efb7cca963 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,8 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float4_e2m1fn = + py::dtype::from_args(ml_dtypes.attr("float4_e2m1fn")).num(); numpy_dtypes.float8_e3m4 = py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); numpy_dtypes.float8_e4m3 = @@ -75,6 +77,8 @@ struct MlDtypesInitInfo { py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")).num(); numpy_dtypes.float8_e5m2fnuz = py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")).num(); + numpy_dtypes.float8_e8m0fnu = + py::dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")).num(); numpy_dtypes.int4 = py::dtype::from_args(ml_dtypes.attr("int4")).num(); numpy_dtypes.uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")).num(); } catch (const std::exception& e) { @@ -85,6 +89,7 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float4_e2m1fn == NPY_NOTYPE || numpy_dtypes.float8_e3m4 == NPY_NOTYPE || numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || @@ -92,6 +97,7 @@ struct MlDtypesInitInfo { numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || numpy_dtypes.float8_e5m2 == NPY_NOTYPE || numpy_dtypes.float8_e5m2fnuz == NPY_NOTYPE || + numpy_dtypes.float8_e8m0fnu == NPY_NOTYPE || numpy_dtypes.int4 == NPY_NOTYPE || numpy_dtypes.uint4 == NPY_NOTYPE) { init_valid = false; } diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index b3aa94e430239a..725d844c27bb4e 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,7 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float4_e2m1fn; int float8_e3m4; int float8_e4m3; int float8_e4m3fn; @@ -31,6 +32,7 @@ struct NumpyDtypes { int float8_e4m3fnuz; int float8_e5m2; int float8_e5m2fnuz; + int float8_e8m0fnu; int int4; int uint4; }; diff --git a/xla/types.h b/xla/types.h index 98e3d7c9331ffc..b702404601dae7 100644 --- a/xla/types.h +++ b/xla/types.h @@ -131,16 +131,32 @@ struct make_specialized_signed>> { template using make_specialized_signed_t = typename make_specialized_signed::type; +// has_negative_zero[_v] + template struct has_negative_zero : std::bool_constant::is_iec559> {}; +template <> +struct has_negative_zero : std::bool_constant {}; + template <> struct has_negative_zero : std::bool_constant {}; template inline constexpr bool has_negative_zero_v = has_negative_zero::value; +// has_zero[_v] + +template +struct has_zero : std::bool_constant {}; + +template <> +struct has_zero : std::bool_constant {}; + +template +inline constexpr bool has_zero_v = has_zero::value; + } // namespace xla #endif // XLA_TYPES_H_ diff --git a/xla/util.cc b/xla/util.cc index c18435a04c64bf..023e09342f113b 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -148,6 +148,7 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { + static_assert(std::numeric_limits::has_quiet_NaN); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, @@ -174,6 +175,10 @@ static std::string GenericRoundTripFpToString(FloatT value) { static_cast(value)); } +std::string RoundTripFpToString(tsl::float4_e2m1fn value) { + return GenericRoundTripFpToString(value); +} + std::string RoundTripFpToString(tsl::float8_e5m2 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); @@ -212,6 +217,11 @@ std::string RoundTripFpToString(tsl::float8_e3m4 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e8m0fnu value) { + std::string result = GenericRoundTripFpToString(value); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index 959009073e96f9..a4578709392445 100644 --- a/xla/util.h +++ b/xla/util.h @@ -416,6 +416,9 @@ std::string VectorString(const std::initializer_list& c) { return VectorString>(c); } +// Returns a string which can losslessly round trip to a float4 E2M1FN. +std::string RoundTripFpToString(tsl::float4_e2m1fn value); + // Returns a string which can losslessly round trip to a float8 E5M2. std::string RoundTripFpToString(tsl::float8_e5m2 value); @@ -437,6 +440,9 @@ std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); // Returns a string which can losslessly round trip to a float8 E3M4. std::string RoundTripFpToString(tsl::float8_e3m4 value); +// Returns a string which can losslessly round trip to a float8 E8M0FNU. +std::string RoundTripFpToString(tsl::float8_e8m0fnu value); + // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); @@ -652,8 +658,9 @@ template auto SignAndMagnitude(T x) { using BitType = UnsignedIntegerTypeForSizeType; BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); - const BitType x_bits = Eigen::numext::bit_cast(x); - const BitType x_sign = x_bits ^ x_abs_bits; + // Eigen implements the sign value to be either all-zeros (for positive input) + // or all-ones (for negative input). + BitType x_sign = Eigen::numext::bit_cast(Eigen::numext::signbit(x)); if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. @@ -664,12 +671,17 @@ auto SignAndMagnitude(T x) { return std::make_pair(x_sign, x_abs_bits); } +template <> +inline auto SignAndMagnitude(tsl::float8_e8m0fnu x) { + uint8_t x_bits = Eigen::numext::bit_cast(x); + return std::make_pair(static_cast(0), x_bits); +} + template auto SignAndMagnitudeToTwosComplement(T sign, T magnitude) { static_assert(!std::numeric_limits::is_signed); using SignedType = std::make_signed_t; - return static_cast(magnitude) ^ - (static_cast(sign) < 0 ? SignedType{-1} : SignedType{0}); + return static_cast(magnitude) ^ static_cast(sign); } // Returns the signed magnitude of T. @@ -679,6 +691,11 @@ auto ToSignMagnitude(T input) { return SignAndMagnitudeToTwosComplement(sign, magnitude); } +template <> +inline auto ToSignMagnitude(tsl::float8_e8m0fnu input) { + return Eigen::numext::bit_cast(input); +} + template constexpr int NanPayloadBits() { // Floating point types with signaling NaNs have payloads. diff --git a/xla/util_test.cc b/xla/util_test.cc index d15329872d911b..aa0e7fcd02a270 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -210,9 +210,9 @@ namespace { template void TotalOrderHelper(T x, T y) { auto x_sm = ToSignMagnitude(x); - bool x_sign = static_cast(Eigen::numext::signbit(x)); - bool y_sign = static_cast(Eigen::numext::signbit(y)); auto y_sm = ToSignMagnitude(y); + bool x_sign = static_cast(SignAndMagnitude(x).first); + bool y_sign = static_cast(SignAndMagnitude(y).first); if (x_sign && !y_sign) { EXPECT_LT(x_sm, y_sm) << x << " " << y; } @@ -243,6 +243,18 @@ void TotalOrderHelper(T x, T y) { } } // namespace +TEST(UtilTest, TotalOrder_F4E2M1FN) { + for (int a = 0; a < 16; ++a) { + tsl::float4_e2m1fn x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 16; ++b) { + tsl::float4_e2m1fn y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E5M2) { for (int a = 0; a < 256; ++a) { tsl::float8_e5m2 x = @@ -329,6 +341,18 @@ TEST(UtilTest, TotalOrder_F8E3M4) { } } +TEST(UtilTest, TotalOrder_F8E8M0FNU) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e8m0fnu x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e8m0fnu y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 01a6415549b584..a225a4cd945f10 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -111,6 +111,17 @@ enum PrimitiveType { F8E5M2FNUZ = 24; F8E4M3FNUZ = 25; + // MX float dtypes, as described in: + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + // + // F4E2M1FN has 2 exponent bits and 1 mantissa bit. + // F8E8M0FNU has 8 exponent bits, no mantissa and no sign. + // + // Only finite values are supported (hence "FN" suffix). Unlike IEEE types, + // infinities and NaNs are not supported. + F4E2M1FN = 32; + F8E8M0FNU = 33; + // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. C128 = 18; // Paired F64 (real, imag), as in std::complex. @@ -136,7 +147,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 32 + // Next = 34 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -584,15 +595,17 @@ message LiteralProto { bytes bf16s = 13; bytes u16s = 16; bytes s16s = 17; - bytes f8e5m2s = 19; - bytes f8e4m3s = 28; - bytes f8e4m3fns = 20; + bytes f4e2m1fns = 32; + bytes f8e3m4s = 29; bytes f8e4m3b11fnuzs = 23; - bytes f8e5m2fnuzs = 24; + bytes f8e4m3fns = 20; bytes f8e4m3fnuzs = 25; - bytes f8e3m4s = 29; + bytes f8e4m3s = 28; + bytes f8e5m2fnuzs = 24; + bytes f8e5m2s = 19; + bytes f8e8m0fnus = 33; repeated int64 sparse_indices = 14; - // Next = 32 + // Next = 34 } message WindowDimension {