From d7d5af74c5f8a94522779a121c0a4a962156fb64 Mon Sep 17 00:00:00 2001
From: Sergey Kozub
Date: Wed, 6 Nov 2024 09:33:32 +0100
Subject: [PATCH] Add F4E2M1FN type: add tests

---
 xla/array2d_test.cc                             | 14 +++++
 .../gpu/codegen/transforms/lower_tensors.cc     | 26 +++++----
 xla/fp_util_test.cc                             | 53 +++++++++++++++++++
 xla/hlo/builder/lib/math.cc                     | 11 ++--
 xla/hlo/builder/lib/math_test.cc                | 34 ++++++++----
 .../simplifiers/float_normalization.cc          |  3 ++
 .../simplifiers/float_normalization_test.cc     |  4 +-
 xla/mlir/utils/type_util.cc                     |  6 ++-
 xla/mlir/utils/type_util_test.cc                |  1 +
 xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir        |  7 +++
 xla/python/ifrt/dtype_test.cc                   | 45 ++++++----------
 xla/service/cpu/cpu_compiler.cc                 |  2 +
 xla/service/cpu/onednn_memory_util.h            |  2 +-
 xla/service/gpu/gpu_compiler.cc                 |  2 +
 xla/service/llvm_ir/llvm_util.cc                |  2 +
 xla/stream_executor/data_type.h                 |  4 ++
 xla/stream_executor/dnn.cc                      |  1 +
 xla/stream_executor/gpu/gpu_blas_lt.cc          |  5 ++
 xla/stream_executor/rocm/hip_blas_utils.cc      |  5 +-
 xla/tests/array_elementwise_ops_test.cc         |  3 +-
 xla/tests/constants_test.cc                     |  8 +--
 xla/tools/driver.cc                             | 14 ++---
 xla/tsl/framework/type_traits.h                 |  1 +
 xla/tsl/protobuf/dnn.proto                      |  1 +
 xla/types.h                                     |  2 +-
 25 files changed, 186 insertions(+), 70 deletions(-)

diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc
index 921da30256fa3..8c314edff2eb6 100644
--- a/xla/array2d_test.cc
+++ b/xla/array2d_test.cc
@@ -219,6 +219,20 @@ TEST(Array2dTest, LinspaceF8E3M4) {
   EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
 }
 
+TEST(Array2dTest, LinspaceF4E2M1FN) {
+  auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);
+
+  EXPECT_EQ(arr->n1(), 3);
+  EXPECT_EQ(arr->n2(), 2);
+
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0);  // 2.5 rounded down
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
+  EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0);  // 3.5 rounded up
+}
+
 TEST(Array2dTest, Stringification) {
   auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
   const std::string expected = R"([[1, 1.5],
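[Illustrative aside, not part of the patch] The two rounding comments in the
test above follow from the F4E2M1FN value grid: with one mantissa bit, the
non-negative representable values are 0, 0.5, 1, 1.5, 2, 3, 4, 6, so 2.5 and
3.5 sit exactly between neighbors and round to the neighbor whose mantissa bit
is even. A self-contained C++ sketch of that rounding rule (helper name and
table are illustrative; this is not the tsl::float4_e2m1fn implementation):

    #include <cmath>
    #include <cstdio>

    // Rounds x to the nearest E2M1 grid value, breaking ties toward the
    // entry with an even mantissa bit (even index in this table).
    float RoundToNearestE2M1(float x) {
      static constexpr float kGrid[] = {0.0f, 0.5f, 1.0f, 1.5f,
                                        2.0f, 3.0f, 4.0f, 6.0f};
      float mag = std::fabs(x);
      int best = 0;
      for (int i = 1; i < 8; ++i) {
        float d_new = std::fabs(kGrid[i] - mag);
        float d_best = std::fabs(kGrid[best] - mag);
        if (d_new < d_best || (d_new == d_best && i % 2 == 0)) best = i;
      }
      return std::copysign(kGrid[best], x);
    }

    int main() {
      std::printf("2.5 -> %g\n", RoundToNearestE2M1(2.5f));  // prints 2
      std::printf("3.5 -> %g\n", RoundToNearestE2M1(3.5f));  // prints 4
    }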
diff --git a/xla/backends/gpu/codegen/transforms/lower_tensors.cc b/xla/backends/gpu/codegen/transforms/lower_tensors.cc
index 38e3671f9613f..dde0a59ae406c 100644
--- a/xla/backends/gpu/codegen/transforms/lower_tensors.cc
+++ b/xla/backends/gpu/codegen/transforms/lower_tensors.cc
@@ -297,7 +297,7 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
 mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
                             Value linear_index,
                             mlir::ImplicitLocOpBuilder& b) {
   Type element_type = tensor.getType().getElementType();
-  if (element_type == b.getI4Type()) {
+  if (element_type.getIntOrFloatBitWidth() == 4) {
     element_type = b.getI8Type();
   }
   auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
@@ -325,8 +325,9 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto linear_index = GetLinearIndex(op.getIndices(), b);
     Type element_type = op.getTensor().getType().getElementType();
+
     Value is_low_nibble = nullptr;
-    if (element_type == rewriter.getI4Type()) {
+    if (element_type.getIntOrFloatBitWidth() == 4) {
       std::tie(linear_index, is_low_nibble) =
           GetI4IndexAndNibble(linear_index, b);
     }
@@ -341,7 +342,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
       auto high_value = b.create(
           load, b.create(4, load.getType()));
       load = b.create(
-          op.getType(),
+          rewriter.getI4Type(),
           b.create(is_low_nibble, load, high_value));
     }
 
@@ -377,6 +378,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
     auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
         op.getSource());
+    mlir::Type source_element_type = source.getType().getElementType();
 
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto linear_index = GetLinearIndex(op.getIndices(), b);
 
@@ -385,7 +387,8 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
     if (vector_type.getElementType().isInteger(1)) {
       vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
     }
-    if (op.getVectorType().getElementType().isInteger(4)) {
+    mlir::Type gep_element_type = vector_type.getElementType();
+    if (gep_element_type.getIntOrFloatBitWidth() == 4) {
       linear_index = b.create(
           linear_index,
           b.create(1, linear_index.getType()));
@@ -397,11 +400,11 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
     auto loaded =
         b.create(llvm_vector_type, gep).getResult();
 
-    if (source.getType().getElementType().isInteger(1)) {
+    if (source_element_type.isInteger(1)) {
       Value zero = b.create(
           mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
       loaded = b.create(arith::CmpIPredicate::ne, loaded, zero);
-    } else if (source.getType().getElementType().isInteger(4)) {
+    } else if (source_element_type.getIntOrFloatBitWidth() == 4) {
       // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
       // elements.
       loaded = PermutePairsInVector(loaded, b);
@@ -430,7 +433,7 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
     auto scalar_value = op.getScalar();
 
     // For i4 we store 2 values into one byte. This needs special handling here.
-    if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
+    if (tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
       // We need to use directly op.getDest() as input, otherwise the following
       // rewrite might remove the only user of it.
       tensor_dest = op.getDest();
@@ -448,6 +451,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
       auto tensor_dest_i8 =
           b.create(tensor_ty, tensor_dest).getResult(0);
 
+      if (scalar_value.getType() != rewriter.getI4Type()) {
+        scalar_value =
+            b.create(rewriter.getI4Type(), scalar_value);
+      }
       scalar_value = b.create(ty, scalar_value);
 
       // We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +514,13 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
     auto linear_index = GetLinearIndex(op.getIndices(), b);
 
     mlir::Value vector_value = op.getVector();
-    if (op.getVectorType().getElementType().isInteger(1)) {
+    mlir::Type vector_element_type = op.getVectorType().getElementType();
+    if (vector_element_type.isInteger(1)) {
       vector_value = b.create(
           op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
           vector_value);
     }
-    if (op.getVectorType().getElementType().isInteger(4)) {
+    if (vector_element_type.getIntOrFloatBitWidth() == 4) {
       linear_index = b.create(
           linear_index,
           b.create(1, linear_index.getType()));
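[Illustrative aside, not part of the patch] All the 4-bit paths above share
one layout assumption: two packed elements per byte, element 2*i in the low
nibble and element 2*i+1 in the high nibble. That is why CreateGep widens
4-bit element types to i8, and why stores go through an atomic
read-modify-write (two threads may own the two nibbles of one byte). A plain
C++ sketch of that byte-level layout (helper names are illustrative, not XLA
APIs):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    uint8_t LoadNibble(const std::vector<uint8_t>& buf, size_t i) {
      uint8_t byte = buf[i / 2];
      return i % 2 == 0 ? (byte & 0x0F) : (byte >> 4);  // low nibble first
    }

    void StoreNibble(std::vector<uint8_t>& buf, size_t i, uint8_t v) {
      // Read-modify-write of the enclosing byte; the real pass emits an
      // atomicrmw loop here, since neighbors can race on the same byte.
      uint8_t& byte = buf[i / 2];
      byte = i % 2 == 0 ? (byte & 0xF0) | (v & 0x0F) : (byte & 0x0F) | (v << 4);
    }

    int main() {
      std::vector<uint8_t> buf(2, 0);
      StoreNibble(buf, 0, 0x3);
      StoreNibble(buf, 1, 0xA);
      std::printf("%x %x\n", LoadNibble(buf, 0), LoadNibble(buf, 1));  // 3 a
    }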
diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc
index 3eb7c54f919b0..09b66c2984b05 100644
--- a/xla/fp_util_test.cc
+++ b/xla/fp_util_test.cc
@@ -119,6 +119,59 @@ class FP8E4M3DistanceTest : public ::testing::Test {};
 using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
 TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);
 
+TEST(FPDistanceTest, F4E2M1FNDistance) {
+  // a & b are equal
+  EXPECT_EQ(CalculateDistanceInFloats(
+                tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)),
+            0);
+
+  // a & b have the same exponents
+  EXPECT_EQ(CalculateDistanceInFloats(
+                tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)),
+            1);
+
+  // a & b have different exponents
+  EXPECT_EQ(CalculateDistanceInFloats(
+                tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)),
+            2);
+
+  // 1 from 0 in the positive direction
+  EXPECT_EQ(CalculateDistanceInFloats(
+                std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
+                tsl::float4_e2m1fn(0)),
+            1);
+
+  // 1 from 0 in the negative direction
+  EXPECT_EQ(CalculateDistanceInFloats(
+                -std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
+                tsl::float4_e2m1fn(0)),
+            1);
+
+  // a & b have different signs
+  EXPECT_EQ(CalculateDistanceInFloats(
+                -std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
+                std::numeric_limits<tsl::float4_e2m1fn>::denorm_min()),
+            2);
+
+  // 1 non denorm from 0 in the positive direction
+  EXPECT_EQ(CalculateDistanceInFloats(
+                std::numeric_limits<tsl::float4_e2m1fn>::min(),
+                tsl::float4_e2m1fn(0)),
+            2);
+
+  // 1 non denorm from 0 in the negative direction
+  EXPECT_EQ(CalculateDistanceInFloats(
+                -std::numeric_limits<tsl::float4_e2m1fn>::min(),
+                tsl::float4_e2m1fn(0)),
+            2);
+
+  // a & b have different signs
+  EXPECT_EQ(CalculateDistanceInFloats(
+                -std::numeric_limits<tsl::float4_e2m1fn>::min(),
+                std::numeric_limits<tsl::float4_e2m1fn>::min()),
+            4);
+}
+
 TEST(FPDistanceTest, F8E3M4Distance) {
   // a & b are equal
   EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0),
                                       tsl::float8_e3m4(8.0)),
            0);
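[Illustrative aside, not part of the patch] The distances asserted above can
be read directly off the F4E2M1FN value line: denorm_min() is 0.5, min() is
1.0, and the distance between two values is the number of representable steps
between them. A rough model of that (the real CalculateDistanceInFloats works
on bit patterns; this table-based version exists only to sanity-check the
expectations):

    #include <cstdio>
    #include <cstdlib>

    int E2M1Distance(float a, float b) {
      // All 15 distinct finite E2M1 values (treating -0 and 0 as one point).
      static constexpr float kValues[] = {-6, -4, -3, -2, -1.5, -1, -0.5, 0,
                                          0.5, 1,  1.5, 2,  3,   4,  6};
      int ia = -1, ib = -1;
      for (int i = 0; i < 15; ++i) {
        if (kValues[i] == a) ia = i;
        if (kValues[i] == b) ib = i;
      }
      return ia < 0 || ib < 0 ? -1 : std::abs(ia - ib);
    }

    int main() {
      std::printf("%d\n", E2M1Distance(4.0f, 6.0f));   // 1: same exponent
      std::printf("%d\n", E2M1Distance(2.0f, 4.0f));   // 2: via 3.0
      std::printf("%d\n", E2M1Distance(-0.5f, 0.5f));  // 2: across zero
      std::printf("%d\n", E2M1Distance(-1.0f, 1.0f));  // 4
    }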
diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc
index f2a77df3d7dda..620e907f8cf11 100644
--- a/xla/hlo/builder/lib/math.cc
+++ b/xla/hlo/builder/lib/math.cc
@@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) {
     case F32:
       return Eq(BitcastConvertType(operand, U32),
                 ConstantR0WithType(&b, U32, uint32_t{1} << 31));
+    case F4E2M1FN:
     case F8E3M4:
     case F8E4M3:
     case F8E5M2:
@@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
   TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
   PrimitiveType a_x_type = a_shape.element_type();
   bool needs_upcast = false;
-  for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
-                             F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
+  for (PrimitiveType type :
+       {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
+        F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
     if (a_shape.element_type() == type) {
       needs_upcast = true;
       break;
@@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
   }
   TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
   bool needs_upcast = false;
-  for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
-                             F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
+  for (PrimitiveType type :
+       {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
+        F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
     if (a_shape.element_type() == type) {
       needs_upcast = true;
       break;
diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc
index 9755643b7586a..84df4f2993b91 100644
--- a/xla/hlo/builder/lib/math_test.cc
+++ b/xla/hlo/builder/lib/math_test.cc
@@ -95,9 +95,13 @@ class MathTypedTest : public MathTest {
     Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
 
     bool has_inf = std::numeric_limits<T>::has_infinity;
+    bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
+    bool is_finite = !has_inf && !has_nan;
+    bool is_nan_only = !has_inf && has_nan;
+
     auto expected = LiteralUtil::MakeTupleOwned(
-        LiteralUtil::CreateR1<bool>(
-            {true, true, true, true, true, false, false, false, false}),
+        LiteralUtil::CreateR1<bool>({true, true, true, true, true, is_finite,
+                                     is_finite, is_finite, is_finite}),
         LiteralUtil::CreateR1<bool>({false, false, false, false, false,
                                      has_inf, has_inf, false, false}),
         LiteralUtil::CreateR1<bool>(
            {false, false, false, false, false, has_inf, false, false, false}),
         LiteralUtil::CreateR1<bool>(
            {false, false, false, false, false, false, has_inf, false, false}),
         LiteralUtil::CreateR1<bool>({false, false, false, false, false,
-                                     !has_inf, !has_inf, true, true}));
+                                     is_nan_only, is_nan_only, has_nan,
+                                     has_nan}));
 
     ComputeAndCompareLiteral(&b, expected, {});
   }
@@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
         LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
         &b));
 
+    bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
     ComputeAndCompareLiteral(
         &b,
         LiteralUtil::CreateR1<bool>(
-            {has_negative_zero_v<T>, false, false, false, false, false, false}),
+            {has_negative_zero_v<T>, false, false, false, false, false, is_mx}),
         {}, error_spec_);
   }
 
@@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
   // For good measure, we also check pow with an exponent other than 0.5.
   void TestSqrtPowInequivalence() {
     SetFastMathDisabled(true);
+    if (std::is_same_v<T, tsl::float4_e2m1fn>) {
+      GTEST_SKIP() << "Skipping due to low precision";
+    }
 
     // Tests disable constant folding by default, but this test needs it
     // enabled, otherwise we don't tickle the bug we're trying to catch.
@@ -181,18 +190,23 @@ class MathTypedTest : public MathTest {
         &b);
     Erf(x);
 
-    bool has_inf = std::numeric_limits<T>::has_infinity;
-    std::vector<T> expected = {
-        has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)};
+    bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
+                      std::numeric_limits<T>::has_quiet_NaN;
+    std::vector<T> expected = {inf_as_nan ? nan : T(-1),
+                               inf_as_nan ? nan : T(1),
+                               T(-0),
+                               T(0),
+                               T(-1),
+                               T(1)};
 
     ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
   }
 };
 
 using TestTypes =
     ::testing::Types
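[Illustrative aside, not part of the patch] The new expectations in
math_test.cc key everything off two std::numeric_limits traits. F4E2M1FN has
neither infinity nor NaN (every one of its 16 bit patterns is a finite
number), so IsFinite must be identically true and IsInf/IsNan identically
false; the f8*fn types, by contrast, drop infinity but keep NaN, which is
what the is_nan_only branch covers. A minimal check of those traits (the
include path mirrors what xla/types.h uses):

    #include <cstdio>
    #include <limits>

    #include "tsl/platform/ml_dtypes.h"

    int main() {
      using F4 = tsl::float4_e2m1fn;
      std::printf("has_inf=%d has_nan=%d\n",
                  std::numeric_limits<F4>::has_infinity,
                  std::numeric_limits<F4>::has_quiet_NaN);  // both print 0
    }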
diff --git a/xla/hlo/transforms/simplifiers/float_normalization.cc b/xla/hlo/transforms/simplifiers/float_normalization.cc
--- a/xla/hlo/transforms/simplifiers/float_normalization.cc
+++ b/xla/hlo/transforms/simplifiers/float_normalization.cc
@@ ... @@
   ShapeUtil::ForEachMutableSubshape(
       hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) {
         if (subshape->element_type() == from) {
           subshape->set_element_type(to);
+          if (subshape->has_layout() && from == F4E2M1FN) {
+            subshape->mutable_layout()->set_element_size_in_bits(0);
+          }
         }
       });
   float_normalization_->UpdateLayout(hlo->mutable_shape());
diff --git a/xla/hlo/transforms/simplifiers/float_normalization_test.cc b/xla/hlo/transforms/simplifiers/float_normalization_test.cc
index 86ec889abc652..9017b6dc47b61 100644
--- a/xla/hlo/transforms/simplifiers/float_normalization_test.cc
+++ b/xla/hlo/transforms/simplifiers/float_normalization_test.cc
@@ -150,7 +150,9 @@ class FloatNormalizationF8Test
     public ::testing::WithParamInterface<PrimitiveType> {};
 
 INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test,
-                         ::testing::Values(F8E3M4, F8E4M3, F8E5M2));
+                         ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3,
+                                           F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ,
+                                           F8E5M2, F8E5M2FNUZ));
 
 TEST_F(FloatNormalizationTest, NoopIfSupported) {
   auto builder = HloComputation::Builder(TestName());
diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc
index 2581390a1e13d..1b7ab6d08a9ac 100644
--- a/xla/mlir/utils/type_util.cc
+++ b/xla/mlir/utils/type_util.cc
@@ -32,6 +32,8 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
   switch (type) {
     case xla::PrimitiveType::PRED:
       return b.getI1Type();
+    case xla::PrimitiveType::F4E2M1FN:
+      return b.getFloat4E2M1FNType();
     case xla::PrimitiveType::F8E5M2:
       return b.getFloat8E5M2Type();
     case xla::PrimitiveType::F8E4M3:
@@ -78,7 +80,9 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
 }
 
 xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) {
-  if (type.isFloat8E5M2()) {
+  if (type.isFloat4E2M1FN()) {
+    return xla::PrimitiveType::F4E2M1FN;
+  } else if (type.isFloat8E5M2()) {
     return xla::PrimitiveType::F8E5M2;
   } else if (type.isFloat8E4M3()) {
     return xla::PrimitiveType::F8E4M3;
diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc
index a8043ab0b5f14..5f9d21e14cac1 100644
--- a/xla/mlir/utils/type_util_test.cc
+++ b/xla/mlir/utils/type_util_test.cc
@@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P(
     Execute, TypeUtilTest,
     ::testing::ValuesIn(std::vector(
         {{PRED, [](mlir::Builder b) { return b.getI1Type(); }},
+         {F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }},
          {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }},
          {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }},
          {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }},
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index d07a178c6c4e7..d5fde8a5c7305 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor {
+func.func @f4e2m1fn(%arg0: tensor) -> tensor {
+  %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor
+  func.return %0 : tensor
+}
+
+// -----
+
 func.func @f8e3m4(%arg0: tensor) -> tensor {
   %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor
   func.return %0 : tensor
diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc
index e6b730672621d..384117353cf27 100644
--- a/xla/python/ifrt/dtype_test.cc
+++ b/xla/python/ifrt/dtype_test.cc
@@ -66,36 +66,21 @@ TEST(DTypeTest, ByteSize) {
 TEST(DTypeTest, BitSize) {
   for (const auto& [kind, bit_size] :
        std::vector<std::pair<DType::Kind, int>>({
-           {DType::kS2, 2},
-           {DType::kU2, 2},
-           {DType::kS4, 4},
-           {DType::kU4, 4},
-           {DType::kPred, 8},
-           {DType::kS8, 8},
-           {DType::kU8, 8},
-           {DType::kF4E2M1FN, 4},
-           {DType::kF8E3M4, 8},
-           {DType::kF8E4M3, 8},
-           {DType::kF8E4M3FN, 8},
-           {DType::kF8E4M3B11FNUZ, 8},
-           {DType::kF8E4M3FNUZ, 8},
-           {DType::kF8E5M2, 8},
-           {DType::kF8E5M2FNUZ, 8},
-           {DType::kS16, 16},
-           {DType::kU16, 16},
-           {DType::kF16, 16},
-           {DType::kBF16, 16},
-           {DType::kS32, 32},
-           {DType::kU32, 32},
-           {DType::kF32, 32},
-           {DType::kS64, 64},
-           {DType::kU64, 64},
-           {DType::kF64, 64},
-           {DType::kC64, 64},
-           {DType::kC128, 128},
-           {DType::kToken, -1},
-           {DType::kInvalid, -1},
-           {DType::kString, -1},
+           {DType::kS2, 2},         {DType::kU2, 2},
+           {DType::kS4, 4},         {DType::kU4, 4},
+           {DType::kPred, 8},       {DType::kS8, 8},
+           {DType::kU8, 8},         {DType::kF4E2M1FN, 4},
+           {DType::kF8E3M4, 8},     {DType::kF8E4M3, 8},
+           {DType::kF8E4M3FN, 8},   {DType::kF8E4M3B11FNUZ, 8},
+           {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8},
+           {DType::kF8E5M2FNUZ, 8}, {DType::kS16, 16},
+           {DType::kU16, 16},       {DType::kF16, 16},
+           {DType::kBF16, 16},      {DType::kS32, 32},
+           {DType::kU32, 32},       {DType::kF32, 32},
+           {DType::kS64, 64},       {DType::kU64, 64},
+           {DType::kF64, 64},       {DType::kC64, 64},
+           {DType::kC128, 128},     {DType::kToken, -1},
+           {DType::kInvalid, -1},   {DType::kString, -1},
        })) {
     EXPECT_EQ(DType(kind).bit_size(),
              bit_size == -1 ? std::nullopt : std::make_optional(bit_size));
diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc
index 3ffb34ecedbc4..bc64f6dcb419e 100644
--- a/xla/service/cpu/cpu_compiler.cc
+++ b/xla/service/cpu/cpu_compiler.cc
@@ -602,6 +602,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<FloatNormalization>(&s4_support);
   FloatSupport u4_support(U4, U8);
   pipeline.AddPass<FloatNormalization>(&u4_support);
+  FloatSupport f4e2m1fn_support(F4E2M1FN, F16);
+  pipeline.AddPass<FloatNormalization>(&f4e2m1fn_support);
   // After canonicalization, there may be more batch dots that can be
   // simplified.
   pipeline.AddPass<BatchDotSimplification>();
diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h
index 18841d2712dcb..2a9fa25ca185b 100644
--- a/xla/service/cpu/onednn_memory_util.h
+++ b/xla/service/cpu/onednn_memory_util.h
@@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) {
     // TODO(intel-tf): properly handle not supported types:
     //  S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4,
-    //  F8E4M3B11FNUZ, F8E4M3, F8E3M4
+    //  F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN
     default:
       return dt::undef;
   }
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 9c9467546f1a3..6c45121ae1ff5 100755
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -1491,6 +1491,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
   const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16);
   const GpuFloatSupport s4_support(gpu_version, S4, S8);
   const GpuFloatSupport u4_support(gpu_version, U4, U8);
+  const GpuFloatSupport f4e2m1fn_support(gpu_version, F4E2M1FN, F16);
 
   auto add_float_normalization = [&](HloPassPipeline& pipeline) {
     auto& sub_pipeline =
         pipeline.AddPass<HloPassPipeline>("float_normalization");
@@ -1504,6 +1505,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
     sub_pipeline.AddPass<FloatNormalization>(&f8e3m4_support);
     sub_pipeline.AddPass<FloatNormalization>(&s4_support);
     sub_pipeline.AddPass<FloatNormalization>(&u4_support);
+    sub_pipeline.AddPass<FloatNormalization>(&f4e2m1fn_support);
     // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization.
     if (debug_options.xla_allow_excess_precision()) {
       sub_pipeline.AddPass<SimplifyFPConversions>();
diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc
index d56172dd4b254..c4be8818f5ff6 100644
--- a/xla/service/llvm_ir/llvm_util.cc
+++ b/xla/service/llvm_ir/llvm_util.cc
@@ -199,6 +199,8 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
     case S16:
     case U16:
       return llvm::Type::getInt16Ty(context);
+    case F4E2M1FN:
+      return llvm::Type::getIntNTy(context, 4);
     case F8E5M2:
     case F8E5M2FNUZ:
    case F8E4M3:
diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h
index f5246389e485c..dc77c0f711a45 100644
--- a/xla/stream_executor/data_type.h
+++ b/xla/stream_executor/data_type.h
@@ -37,6 +37,10 @@ struct ToDataType;
 // Note: If you add a new specialization below, make sure to add the
 // corresponding definition in stream_executor/dnn.cc.
 template <>
+struct ToDataType<tsl::float4_e2m1fn> {
+  static constexpr DataType value = DataType::kF4E2M1FN;
+};
+template <>
 struct ToDataType<tsl::float8_e3m4> {
   static constexpr DataType value = DataType::kF8E3M4;
 };
diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc
index 6b7a87d80b3ae..5d41b3152790b 100644
--- a/xla/stream_executor/dnn.cc
+++ b/xla/stream_executor/dnn.cc
@@ -69,6 +69,7 @@ bool ProtoMapsEqual(const google::protobuf::Map& x,
 }  // namespace
 
+constexpr DataType ToDataType<tsl::float4_e2m1fn>::value;
 constexpr DataType ToDataType<tsl::float8_e3m4>::value;
 constexpr DataType ToDataType<tsl::float8_e4m3>::value;
 constexpr DataType ToDataType<tsl::float8_e4m3fn>::value;
diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc
index 6aee86bf2cbc1..41db3c6c90104 100644
--- a/xla/stream_executor/gpu/gpu_blas_lt.cc
+++ b/xla/stream_executor/gpu/gpu_blas_lt.cc
@@ -56,6 +56,8 @@ absl::StatusOr<DataType> AsBlasDataType(PrimitiveType dtype) {
       return DataType::kF8E4M3FNUZ;
     case PrimitiveType::F8E3M4:
       return DataType::kF8E3M4;
+    case PrimitiveType::F4E2M1FN:
+      return DataType::kF4E2M1FN;
     case PrimitiveType::S8:
       return DataType::kInt8;
     case PrimitiveType::F16:
@@ -93,6 +95,8 @@ absl::StatusOr<PrimitiveType> AsXlaPrimitiveType(DataType dtype) {
      return PrimitiveType::F8E4M3FNUZ;
    case DataType::kF8E3M4:
      return PrimitiveType::F8E3M4;
+    case DataType::kF4E2M1FN:
+      return PrimitiveType::F4E2M1FN;
    case DataType::kInt8:
      return PrimitiveType::S8;
    case DataType::kHalf:
@@ -154,6 +158,7 @@ absl::StatusOr<ComputationType> GetBlasComputationType(
     case PrimitiveType::F8E5M2FNUZ:  // fall-through
     case PrimitiveType::F8E4M3FNUZ:  // fall-through
     case PrimitiveType::F8E3M4:      // fall-through
+    case PrimitiveType::F4E2M1FN:    // fall-through
     case PrimitiveType::F16:         // fall-through
     case PrimitiveType::BF16:
       // Accumulate in f32 precision.
diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc
index e5730121addd8..dd1cbcf2ea29c 100644
--- a/xla/stream_executor/rocm/hip_blas_utils.cc
+++ b/xla/stream_executor/rocm/hip_blas_utils.cc
@@ -39,8 +39,9 @@ hipDataType AsHipblasDataType(blas::DataType type) {
     case blas::DataType::kF8E4M3:
     case blas::DataType::kF8E4M3FN:
     case blas::DataType::kF8E3M4:
-      LOG(FATAL)
-          << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4";
+    case blas::DataType::kF4E2M1FN:
+      LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, "
+                    "F8E3M4 and F4E2M1FN";
 #if TF_ROCM_VERSION >= 60000
     case blas::DataType::kF8E5M2FNUZ:
       return HIP_R_8F_E5M2_FNUZ;
diff --git a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc
index c12ce79a06e8f..1054c72fea658 100644
--- a/xla/tests/array_elementwise_ops_test.cc
+++ b/xla/tests/array_elementwise_ops_test.cc
@@ -1423,7 +1423,8 @@ class TotalOrderTest : public ClientLibraryTestBase {
   }
 };
 
-using Types = ::testing::Types
diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc
--- a/xla/tests/constants_test.cc
+++ b/xla/tests/constants_test.cc
@@ ... @@
 class ConstantsFloatTest : public ConstantsTest {};
 
 using FloatTypes =
-    ::testing::Types;
+    ::testing::Types;
 
 TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes);
diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc
index 7f0d9c4507a2a..2db53e65cf75c 100644
--- a/xla/tools/driver.cc
+++ b/xla/tools/driver.cc
@@ -121,6 +121,7 @@ enum PrimitiveType {
   F64,
   C64,
   C128,
+  F4E2M1FN,
   F8E5M2,
   F8E4M3,
   F8E4M3FN,
@@ -132,12 +133,11 @@ enum PrimitiveType {
 
 const std::vector<std::string>& primitive_strings() {
   static auto vec = new std::vector<std::string>(
-      {"s1",     "s2",     "s4",       "s8",
-       "s16",    "s32",    "s64",      "u1",
-       "u2",     "u4",     "u8",       "u16",
-       "u32",    "u64",    "f16",      "bf16",
-       "f32",    "f64",    "c64",      "c128",
-       "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz",
+      {"s1",        "s2",     "s4",     "s8",       "s16",
+       "s32",       "s64",    "u1",     "u2",       "u4",
+       "u8",        "u16",    "u32",    "u64",      "f16",
+       "bf16",      "f32",    "f64",    "c64",      "c128",
+       "f4e2m1fn",  "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz",
        "f8e5m2fnuz", "f8e4m3fnuz", "f8e3m4"});
   return *vec;
 }
@@ -415,6 +415,7 @@ void Fill(void* buffer, const ArrayShape& shape) {
     case F64:
       return FillFloatT<double>(buffer, num_elements);
 
+    case F4E2M1FN:
     case F8E5M2:
     case F8E4M3:
     case F8E4M3FN:
@@ -475,6 +476,7 @@ void Display(const void* buffer, const ArrayShape& shape) {
     case F64:
       return DisplayT<double>(buffer, num_elements);
 
+    case F4E2M1FN:
     case F8E5M2:
     case F8E4M3:
     case F8E4M3FN:
diff --git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h
index f7a9bd7a54bc9..0a0deb5f5f66b 100644
--- a/xla/tsl/framework/type_traits.h
+++ b/xla/tsl/framework/type_traits.h
@@ -70,6 +70,7 @@ struct is_simple_type {
       std::is_trivial<T>::value || std::is_same<T, Eigen::half>::value ||
       std::is_same<T, complex64>::value || std::is_same<T, complex128>::value ||
       is_quantized<T>::value || std::is_same<T, bfloat16>::value ||
+      std::is_same<T, float4_e2m1fn>::value ||
       std::is_same<T, float8_e3m4>::value ||
       std::is_same<T, float8_e4m3>::value ||
       std::is_same<T, float8_e4m3fn>::value ||
diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto
index 2ac31005c1662..09237a09e45b8 100644
--- a/xla/tsl/protobuf/dnn.proto
+++ b/xla/tsl/protobuf/dnn.proto
@@ -24,6 +24,7 @@ enum DataType {
   kInt64 = 12;
   kF8E4M3 = 13;
   kF8E3M4 = 14;
+  kF4E2M1FN = 15;
 }
 
 // Describes how a convolution input or output layer's data is formatted.
diff --git a/xla/types.h b/xla/types.h
index 58728d44dfcb4..f851ea5a66d59 100644
--- a/xla/types.h
+++ b/xla/types.h
@@ -23,7 +23,7 @@ limitations under the License.
 #include
 
 #include "absl/strings/str_cat.h"
-#include "Eigen/Core" // IWYU pragma: export
+#include "Eigen/Core"  // IWYU pragma: export
 #include "tsl/platform/ml_dtypes.h"  // IWYU pragma: export
 
 namespace xla {
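[Illustrative aside, not part of the patch] Taken together, the patch threads
the new element type through every enum layer (xla::PrimitiveType, the
stream_executor DataType, and the dnn.proto DataType), and the paired
switches in gpu_blas_lt.cc are expected to remain inverses of each other. A
stand-in sketch of that round-trip property with mock enums (not the real XLA
types):

    #include <cassert>
    #include <cstdio>

    enum class PrimitiveType { F4E2M1FN, F8E3M4 };
    enum class DataType { kF4E2M1FN, kF8E3M4 };

    DataType AsBlasDataType(PrimitiveType t) {
      return t == PrimitiveType::F4E2M1FN ? DataType::kF4E2M1FN
                                          : DataType::kF8E3M4;
    }

    PrimitiveType AsXlaPrimitiveType(DataType t) {
      return t == DataType::kF4E2M1FN ? PrimitiveType::F4E2M1FN
                                      : PrimitiveType::F8E3M4;
    }

    int main() {
      for (PrimitiveType p : {PrimitiveType::F4E2M1FN, PrimitiveType::F8E3M4}) {
        assert(AsXlaPrimitiveType(AsBlasDataType(p)) == p);  // must round-trip
      }
      std::puts("round-trip ok");
    }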