PR #21380: Add F4E2M1FN and F8E8M0FNU types
Imported from GitHub PR openxla/xla#21380

The previous PR (openxla/xla#19096) was rolled back; this is a retry.

This PR adds the F4E2M1FN primitive type (4-bit float with 2 exponent bits and 1 mantissa bit) and the F8E8M0FNU primitive type (8-bit float with 8 exponent bits, no mantissa, and no sign), and enables loads/stores in the same way the S4/U4 types are implemented.

This will enable the use of microscaling (MX) formats ([RFC](openxla/xla#18085)), such as MXFP4.
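For orientation, here is a minimal sketch of how an MXFP4 block could be laid out, assuming the OCP MX convention of one shared 8-bit scale per 32 elements; the struct and field names are hypothetical and not part of this PR:

```c
#include <stdint.h>

/*
 * Illustrative MXFP4 block layout (assumption: OCP MX convention of
 * 32 elements per block). One shared F8E8M0FNU scale plus 32 F4E2M1FN
 * elements packed two per byte.
 */
typedef struct {
  uint8_t scale;         /* F8E8M0FNU: power-of-two scale, 2^(scale - 127) */
  uint8_t elements[16];  /* 32 x F4E2M1FN, two 4-bit values per byte */
} mxfp4_block;
```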

```
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```
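As an illustration of the parameters listed above, the following is a hedged reference sketch (not the XLA implementation; the function names are made up for this example) that decodes raw F4E2M1FN and F8E8M0FNU bit patterns to `float`:

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Decode a 4-bit F4E2M1FN value: 1 sign bit, 2 exponent bits, 1 mantissa bit. */
static float decode_f4e2m1fn(uint8_t bits) {
  int sign = (bits >> 3) & 1;
  int exp = (bits >> 1) & 3;  /* stored exponent, bias 1 */
  int man = bits & 1;         /* single mantissa bit */
  float value;
  if (exp == 0) {
    /* Zero or subnormal: no implicit leading 1, unbiased exponent fixed at 0. */
    value = man * 0.5f;
  } else {
    /* Normal: implicit leading 1, unbiased exponent = exp - 1. */
    value = (1.0f + man * 0.5f) * ldexpf(1.0f, exp - 1);
  }
  return sign ? -value : value;
}

/* Decode an 8-bit F8E8M0FNU value: exponent only, bias 127, no sign bit. */
static float decode_f8e8m0fnu(uint8_t bits) {
  if (bits == 0xFF) return NAN;          /* 1111'1111 is the only NaN */
  return ldexpf(1.0f, (int)bits - 127);  /* mantissa is always 1 */
}

int main(void) {
  printf("%g\n", decode_f4e2m1fn(0x7));    /* S.11.1 -> 6.0 (max normal) */
  printf("%g\n", decode_f4e2m1fn(0x9));    /* S.00.1, sign set -> -0.5 */
  printf("%g\n", decode_f8e8m0fnu(0x7F));  /* 2^0 -> 1.0 */
  /* MXFP4-style usage: shared F8E8M0FNU scale times an F4E2M1FN element. */
  printf("%g\n", decode_f8e8m0fnu(0x80) * decode_f4e2m1fn(0x5));  /* 2^1 * 3.0 = 6.0 */
  return 0;
}
```

Multiplying a shared F8E8M0FNU scale into each decoded F4E2M1FN element, as on the last line, is the basic MX recipe these types are intended to support.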

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028
Copybara import of the project:

--
d7e00c49a4b4f26c06266d6bb941275e67464c01 by Sergey Kozub <skozub@nvidia.com>:

Add F4E2M1FN and F8E8M0FNU types

Merging this change closes #21380

PiperOrigin-RevId: 715434229
sergey-kozub authored and TensorFlow MLIR Team committed Jan 14, 2025
1 parent 4484fd7 commit fd28e84
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/Dialect/mhlo/ops.mlir
```diff
@@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b
 
 // -----
 
+func.func @f4e2m1fn(%arg0: tensor<f16>) -> tensor<f4E2M1FN> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f4E2M1FN>
+  func.return %0 : tensor<f4E2M1FN>
+}
+
+// -----
+
 func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
   %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
   func.return %0 : tensor<f8E3M4>
@@ -6872,6 +6879,13 @@ func.func @f8e5m2(%arg0: tensor<f16>) -> tensor<f8E5M2> {
 
 // -----
 
+func.func @f8e8m0fnu(%arg0: tensor<f16>) -> tensor<f8E8M0FNU> {
+  %0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E8M0FNU>
+  func.return %0 : tensor<f8E8M0FNU>
+}
+
+// -----
+
 func.func @top_k_1d(%arg0 : tensor<16xf32>) {
   %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
   return
```
