PR #21380: Add F4E2M1FN and F8E8M0FNU types
Imported from GitHub PR #21380

The previous PR #19096 was rolled back; this PR retries the change.

This PR adds the F4E2M1FN primitive type (a 4-bit float with 2 exponent bits and 1 mantissa bit) and the F8E8M0FNU primitive type (an 8-bit float with 8 exponent bits, no mantissa, and no sign), and enables loads/stores in the same way the S4/U4 types are implemented.

This enables the use of microscaling (MX) formats ([RFC](#18085)), such as MXFP4.

```
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 - 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 - 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```
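For concreteness, here is a minimal decoding sketch (illustration only, not part of this change) that follows the bit layouts above; the last line shows the microscaling recipe of multiplying a block's F4E2M1FN elements by a shared F8E8M0FNU scale. All function names are hypothetical.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decodes one F4E2M1FN nibble (layout S.EE.M, exponent bias 1).
float DecodeF4E2M1FN(uint8_t nibble) {
  int sign = (nibble >> 3) & 0x1;
  int exponent = (nibble >> 1) & 0x3;  // stored exponent
  int mantissa = nibble & 0x1;
  float magnitude =
      (exponent == 0)
          ? mantissa * 0.5f  // subnormal: 2^0 * 0.M
          : std::ldexp(1.0f + mantissa * 0.5f, exponent - 1);  // 2^(E-1) * 1.M
  return sign ? -magnitude : magnitude;
}

// Decodes one F8E8M0FNU byte (pure exponent, bias 127, no sign, no zero).
float DecodeF8E8M0FNU(uint8_t bits) {
  if (bits == 0xFF) return std::nanf("");  // 1111'1111 encodes NaN
  return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

int main() {
  std::printf("%g\n", DecodeF4E2M1FN(0b0111));  // max normal: 6.0
  std::printf("%g\n", DecodeF4E2M1FN(0b1001));  // min negative subnormal: -0.5
  std::printf("%g\n", DecodeF8E8M0FNU(0x7F));   // 2^0 = 1.0
  // MXFP4-style dequantization: shared E8M0 scale times an E2M1 element.
  std::printf("%g\n", DecodeF8E8M0FNU(0x80) * DecodeF4E2M1FN(0b0101));  // 2 * 3 = 6
}
```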

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028
Copybara import of the project:

--
d7e00c4 by Sergey Kozub <skozub@nvidia.com>:

Add F4E2M1FN and F8E8M0FNU types

Merging this change closes #21380

FUTURE_COPYBARA_INTEGRATE_REVIEW=#21380 from openxla:skozub/e2m1_e8m0 d7e00c4
PiperOrigin-RevId: 715070992
sergey-kozub authored and Google-ML-Automation committed Jan 14, 2025
1 parent 7d912e5 commit a4074ba
Showing 79 changed files with 1,853 additions and 377 deletions.
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
@@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF4E2M1FN) {
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, LinspaceF8E8M0FNU) {
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
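The expectations in the two new tests follow from the representable value sets; a quick reference (consistent with the rounding comments in the tests above):

```cpp
// F4E2M1FN magnitudes: 0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0.
//   Round-to-nearest-even on the mantissa: 2.5 -> 2.0 (S.10.0, even)
//   and 3.5 -> 4.0 (S.11.0, even).
// F8E8M0FNU values: powers of two only (..., 1.0, 2.0, 4.0, ...).
//   2.5 -> 2.0 because 2.0 is strictly nearest; 1.5 -> 2.0 and
//   3.0 -> 4.0 are ties that round up in these tests.
```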
191 changes: 119 additions & 72 deletions xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Large diffs are not rendered by default.

59 changes: 35 additions & 24 deletions xla/backends/gpu/codegen/transforms/lower_tensors.cc
@@ -297,7 +297,8 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
Value linear_index, mlir::ImplicitLocOpBuilder& b) {
Type element_type = tensor.getType().getElementType();
if (element_type == b.getI4Type()) {
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
element_type = b.getI8Type();
}
auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
@@ -326,7 +327,8 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);
Type element_type = op.getTensor().getType().getElementType();
Value is_low_nibble = nullptr;
if (element_type == rewriter.getI4Type()) {
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
std::tie(linear_index, is_low_nibble) =
GetI4IndexAndNibble(linear_index, b);
}
@@ -341,7 +343,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto high_value = b.create<mlir::arith::ShRUIOp>(
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
load = b.create<mlir::arith::TruncIOp>(
op.getType(),
rewriter.getI4Type(),
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
}

@@ -377,6 +379,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {

auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
op.getSource());
mlir::Type source_element_type = source.getType().getElementType();

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto linear_index = GetLinearIndex(op.getIndices(), b);
@@ -385,7 +388,9 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
if (vector_type.getElementType().isInteger(1)) {
vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
}
if (op.getVectorType().getElementType().isInteger(4)) {
mlir::Type gep_element_type = vector_type.getElementType();
if (gep_element_type.isIntOrFloat() &&
gep_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
auto loaded =
b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();

if (source.getType().getElementType().isInteger(1)) {
if (source_element_type.isInteger(1)) {
Value zero = b.create<mlir::arith::ConstantOp>(
mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
} else if (source.getType().getElementType().isInteger(4)) {
} else if (source_element_type.isIntOrFloat() &&
source_element_type.getIntOrFloatBitWidth() == 4) {
// LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
// elements.
loaded = PermutePairsInVector(loaded, b);
@@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto scalar_value = op.getScalar();

// For i4 we store 2 values into one byte. This needs special handling here.
if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
if (tensor_dest.getType().getElementType().isIntOrFloat() &&
tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
// We need to use directly op.getDest() as input, otherwise the following
// rewrite might remove the only user of it.
tensor_dest = op.getDest();
@@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto tensor_dest_i8 =
b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
.getResult(0);
if (scalar_value.getType() != rewriter.getI4Type()) {
scalar_value =
b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
}
scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);

// We need AtomicRMWOp because it can happen that different threads try to
@@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);

mlir::Value vector_value = op.getVector();
if (op.getVectorType().getElementType().isInteger(1)) {
mlir::Type vector_element_type = op.getVectorType().getElementType();
if (vector_element_type.isInteger(1)) {
vector_value = b.create<arith::ExtUIOp>(
op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
vector_value);
}
if (op.getVectorType().getElementType().isInteger(4)) {
if (vector_element_type.isIntOrFloat() &&
vector_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -577,21 +590,19 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
// Needed to support complex element type.
mlir::LLVMTypeConverter converter(b.getContext());
auto llvm_element_type = converter.convertType(element_type);
if (mlir::isa<mlir::IntegerType>(element_type)) {
int bit_width = mlir::cast<mlir::IntegerType>(element_type).getWidth();
if (bit_width == 4) {
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
llvm_element_type = b.getI8Type();
auto unpacked_data =
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
std::vector<char> packed_data(num_elements);
absl::Span<char> packed_data_span =
absl::MakeSpan(packed_data.data(), packed_data.size());
PackIntN(4, unpacked_data, packed_data_span);
value = mlir::DenseElementsAttr::getFromRawBuffer(
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
packed_data);
}
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
llvm_element_type = b.getI8Type();
auto unpacked_data =
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
std::vector<char> packed_data(num_elements);
absl::Span<char> packed_data_span =
absl::MakeSpan(packed_data.data(), packed_data.size());
PackIntN(4, unpacked_data, packed_data_span);
value = mlir::DenseElementsAttr::getFromRawBuffer(
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
packed_data);
}
auto array_ty =
mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements);
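The recurring change in this file widens the old `isInteger(4)` / `getI4Type()` checks to any 4-bit int-or-float element type, so f4E2M1FN reuses the existing i4 packing path: two elements per byte, with the byte index obtained by shifting the linear index right by one (the `ShRUIOp` above). A standalone sketch of that addressing follows; the nibble placement is inferred from the packed constant test later in this PR, and XLA and LLVM order nibbles oppositely, hence the `PermutePairsInVector` reshuffle in `RewriteTransferRead`.

```cpp
#include <cstdint>

// Extracts the 4-bit element at linear_index from a packed buffer.
uint8_t ExtractPacked4Bit(const uint8_t* data, int64_t linear_index) {
  uint8_t byte = data[linear_index >> 1];  // two elements per byte
  // Placement inferred from the f4_constant test (dense<[25, 64]>):
  // the first element of each pair sits in the high nibble.
  return (linear_index & 1) == 0 ? (byte >> 4) : (byte & 0xF);
}
```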
50 changes: 50 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir
@@ -115,3 +115,53 @@
// CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
// CHECK: arith.truncf %[[EXT]] : f32 to f16
// CHECK-NOT: arith.truncf

// -----

module {
func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
%ret = arith.extf %arg0 : f4E2M1FN to f16
return %ret : f16
}
}

// CHECK-LABEL: @f4_to_f16
// CHECK-NOT: arith.extf

// -----

module {
func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
%ret = arith.truncf %arg0 : f16 to f4E2M1FN
return %ret : f4E2M1FN
}
}

// CHECK-LABEL: @f16_to_f4
// CHECK-NOT: arith.truncf

// -----

module {
func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
%ret = math.absf %arg0 : f4E2M1FN
return %ret : f4E2M1FN
}
}

// CHECK-LABEL: @f4_abs
// CHECK-NOT: math.absf
// CHECK: arith.constant 7 : i4

// -----

module {
func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
%ret = math.absf %arg0 : f8E8M0FNU
return %ret : f8E8M0FNU
}
}

// CHECK-LABEL: @e8m0_abs
// CHECK-NOT: math.absf
// CHECK: return %arg0
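The two `absf` expansions above differ for a format-specific reason: f4E2M1FN has a sign bit, so `math.absf` lowers to clearing it (the `arith.constant 7 : i4` is the 0b0111 mask), while f8E8M0FNU has no sign bit, so `absf` folds to the identity. A one-line equivalent of each, as a sketch:

```cpp
// absf on F4E2M1FN clears bit 3 (the sign): mask 0b0111 == 7.
inline uint8_t AbsF4E2M1FN(uint8_t nibble) { return nibble & 0x7; }
// absf on F8E8M0FNU is the identity: the format has no sign bit.
inline uint8_t AbsF8E8M0FNU(uint8_t bits) { return bits; }
```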
42 changes: 41 additions & 1 deletion xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir
@@ -763,4 +763,44 @@ func.func @for_op(%arg0: tensor<500xf32>) -> f32 {

// CHECK-LABEL: @for_op
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) {

// -----

func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN {
%cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN>
%extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN>
%extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN>
%0 = arith.addf %extracted, %extracted_0 : f4E2M1FN
return %0 : f4E2M1FN
}
// CHECK: llvm.mlir.global private constant
// CHECK-SAME: dense<[25, 64]>
// CHECK-LABEL: @f4_constant

// -----

func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> {
%c16 = arith.constant 16 : index
%c0 = arith.constant 0.0 : f4E2M1FN
%out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN>
func.return %out : vector<2xf4E2M1FN>
}
// CHECK-LABEL: @transfer_read_f4
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8]
// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4>
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN>
// CHECK: return %[[OUT]] : vector<2xf4E2M1FN>

// -----

func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1},
%arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> {
%c10 = arith.constant 10 : index
%out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN>
func.return %out : tensor<43xf4E2M1FN>
}
// CHECK-LABEL: @transfer_write_f4
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4>
// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr
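The CHECK constants in these tests line up with the format table in the description: in `@f4_constant`, [0.5, -0.5, 2.5] encodes as nibbles 1, 9, 4 (2.5 rounds to 2.0 = 0b0100), giving packed bytes {0x19, 0x40} = {25, 64}; in the transfer tests, logical indices 16 and 10 halve to byte offsets 8 and 5. A sketch reproducing the packing (hypothetical helper; the first-element-in-high-nibble placement is inferred from the expected bytes):

```cpp
#include <cstdint>
#include <vector>

// Packs 4-bit codes two per byte, first element of each pair in the
// high nibble, matching the dense<[25, 64]> constant checked above.
std::vector<uint8_t> PackF4(const std::vector<uint8_t>& nibbles) {
  std::vector<uint8_t> packed((nibbles.size() + 1) / 2, 0);
  for (size_t i = 0; i < nibbles.size(); ++i) {
    int shift = (i % 2 == 0) ? 4 : 0;
    packed[i / 2] |= (nibbles[i] & 0xF) << shift;
  }
  return packed;  // PackF4({1, 9, 4}) == {0x19, 0x40} == {25, 64}
}
```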
9 changes: 7 additions & 2 deletions xla/comparison_util.h
@@ -193,8 +193,13 @@ class Comparison {
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
// Reference:
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
if constexpr (std::numeric_limits<T>::is_signed) {
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
} else {
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
}
}
}
// Applies the comparison from this Comparison's direction and ordering.
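The new `if constexpr` branch exists because F8E8M0FNU is unsigned: `std::numeric_limits<T>::is_signed` is false for it, and for a sign-free format the stored bits are already monotonically ordered, so the total order can compare them as an unsigned integer; signed formats first map to sign-magnitude. A minimal sketch of the unsigned case, assuming NaN = 0xFF (which correctly sorts last in the order quoted above):

```cpp
#include <cstdint>

// Total-order "less than" for F8E8M0FNU bit patterns: values are
// 2^(bits - 127), monotone in the stored byte, and NaN (0xFF) sorts last.
inline bool E8M0TotalOrderLess(uint8_t a_bits, uint8_t b_bits) {
  return a_bits < b_bits;
}
```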
4 changes: 4 additions & 0 deletions xla/ffi/api/api.h
@@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "C128";
case XLA_FFI_DataType_TOKEN:
return os << "TOKEN";
case XLA_FFI_DataType_F4E2M1FN:
return os << "F4E2M1FN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E3M4:
@@ -145,6 +147,8 @@
return os << "F8E5M2FNUZ";
case XLA_FFI_DataType_F8E4M3FNUZ:
return os << "F8E4M3FNUZ";
case XLA_FFI_DataType_F8E8M0FNU:
return os << "F8E8M0FNU";
}
}

2 changes: 2 additions & 0 deletions xla/ffi/api/c_api.h
@@ -201,6 +201,8 @@ typedef enum {
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
XLA_FFI_DataType_F8E4M3FNUZ = 25,
XLA_FFI_DataType_F4E2M1FN = 32,
XLA_FFI_DataType_F8E8M0FNU = 33,
} XLA_FFI_DataType;
// LINT.ThenChange(ffi_test.cc)

6 changes: 6 additions & 0 deletions xla/ffi/api/ffi.h
@@ -79,6 +79,8 @@ enum class DataType : uint8_t {
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
@@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::F8E5M2FNUZ:
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
case DataType::F4E2M1FN:
case DataType::F8E8M0FNU:
return 1;
case DataType::S16:
case DataType::U16:
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi_test.cc
@@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) {

EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN));
EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
@@ -137,6 +138,7 @@
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU));
}

TEST(FfiTest, DataTypeByteWidth) {
@@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128),
ByteWidth(DataType::C128));

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN),
ByteWidth(DataType::F4E2M1FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
@@ -193,6 +197,8 @@
ByteWidth(DataType::F8E4M3FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
ByteWidth(DataType::F8E3M4));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU),
ByteWidth(DataType::F8E8M0FNU));
}

TEST(FfiTest, ErrorEnumValue) {
2 changes: 2 additions & 0 deletions xla/ffi/call_frame.cc
@@ -264,13 +264,15 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C64:
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F4E2M1FN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
case PrimitiveType::F8E4M3FNUZ:
case PrimitiveType::F8E3M4:
case PrimitiveType::F8E8M0FNU:
return static_cast<XLA_FFI_DataType>(primitive_type);
default:
DCHECK(false) << "Unsupported primitive type "