
Microscaling Data Formats in MLIR


Introduction

This page sketches how microscaling (MX) data formats could be represented in MLIR. The formats themselves are described in:

https://arxiv.org/pdf/2310.10537

Basic type

  mx_tensor< <M' x N' x ... x ScaTy>, [ <M x N x ... x SubTn>, ... ] x Type >

  // Where:
  //  * M, N, ... are the tensor dimensions (if there is more than one sub-tensor, they must all have the same dimensions)
  //  * M', N', ... are the scaling factor's dimensions, from a single scalar to M, N, ...
  //    * If scaling factor is 'blocked' by factors m, n, ..., then: M' = M/m, N' = N/n, ...
  //  * ScaTy is the element type of the scaling factor tensor (usually i8)
  //  * Type is the element type of the data (int8, fp4, fp6, fp8, bf16, etc)
  //  * SubTn is a sub-type of Type, used when the data is stored as separate tensors (ex. for aligned access)
  //    * Type can be omitted if there's only one data tensor
  //    * It must be present if there is more than one, to drive int/fp casts when necessary (ex. i2 + i4 -> fp6)

Examples

  // bf16 data type, single scaling factor
  mx_tensor< <1x1xi8>, <1024x1024xbf16> >

  // bf16 data type, tiled scaling factor (different type)
  mx_tensor< <4x4xf16>, <1024x1024xbf16> > // One scaling factor per 256x256 elements

  // fp8 data type, column-wise scaling factor
  mx_tensor< <1024x1xi8>, <1024x1024xfp8> >

  // fp6 data type as two tensors, column-wise tiled scaling factor (one per 32 columns)
  mx_tensor< <32x1xi8>, <1024x1024xi2>, <1024x1024xi4> x fp6 >

  // syntax variation on the form above
  mx_tensor< <32x1xi8>, [ <1024x1024xi2>, <1024x1024xi4> ] x fp6 >

  // syntax variation on the form above
  mx_tensor< <32x1xi8>, < <1024x1024xi2> x <1024x1024xi4> x fp6 > >

Note: fpN here denotes a generic N-bit float, which can be any of the possible encodings; it is used purely for illustration.

The last example needs special consideration. The assumed semantics are (see the sketch after this list):

  • Concat element-wise all tensors on load
  • Cast to the compute type before the operation
  • Split element-wise on write (will need different pointers on memref semantics)
  • Cast to the storage type before write
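
A rough element-level sketch of these semantics for the two-tensor fp6 example, using hypothetical bit.concat and bit.split ops purely to illustrate the intent (not a proposal for real op names):

  // Load: concat the i2 and i4 storage elements, then cast to the compute type
  %lo   = memref.load %data_i2[%i, %j] : memref<1024x1024xi2>
  %hi   = memref.load %data_i4[%i, %j] : memref<1024x1024xi4>
  %bits = bit.concat %hi, %lo : (i4, i2) -> i6   // hypothetical element-wise concat
  %val  = bitcast %bits : i6 to fp6              // cast to the compute type

  // ... compute on %val, producing %res ...

  // Store: cast back to the storage encoding, split, and write through both pointers
  %bits1 = bitcast %res : fp6 to i6
  %hi1, %lo1 = bit.split %bits1 : i6 -> (i4, i2) // hypothetical element-wise split
  memref.store %hi1, %data_i4[%i, %j] : memref<1024x1024xi4>
  memref.store %lo1, %data_i2[%i, %j] : memref<1024x1024xi2>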

Conversion from Tensor

Converts already-quantized tensors into an mx_tensor. Some external process must have created them in the right shapes beforehand. The next section has some ideas for a generic quantization process.

Default behaviour

  // Constructing an MX tensor
  %scale = tensor.empty() : tensor<4x4xi8>
  %data = tensor.empty() : tensor<1024x1024xfp8>
  %0 = mx_tensor.construct %scale, %data : mx_tensor< <4x4xi8>, <1024x1024xfp8> >

  // Destructing an MX tensor
  %scale1, %data1 = mx_tensor.deconstruct %0 : tensor<4x4xi8>, tensor<1024x1024xfp8>

Assumes:

  • Data type remains the same
  • Single data tensor

Split/Concat behaviour

  // Constructing an MX tensor from multiple data tensors
  %scale = tensor.empty() : tensor<4x4xi8>
  %data0 = tensor.empty() : tensor<1024x1024xi2>
  %data1 = tensor.empty() : tensor<1024x1024xi4>
  %0 = mx_tensor.construct %scale, %data0, %data1 { compute_type = fp6 } : mx_tensor< <4x4xi8>, <1024x1024xi2>, <1024x1024xi4> x fp6 >

  // Destructing an MX tensor into multiple data tensors
  %scale1, %data2, %data3 = mx_tensor.deconstruct %0 { storage_types = [i2, i4] } : tensor<4x4xi8>, tensor<1024x1024xi2>, tensor<1024x1024xi4>

Assumes:

  • There are two different attributes: the compute type and the storage types.
  • If not specified, the storage types only need to add up to the compute type's bit-width (2+4=6)
  • The FP cast could be implicit when the compute and storage types are from different families, or it could be an explicit attribute.

Quantization

Converts a regular tensor into an mx_tensor by some quantization algorithm, using a specified blocking factor.

Default behaviour

  // Quantize a tensor
  %data = tensor.empty() : tensor<1024x1024xf32>
  %0 = mx_tensor.quantize %data
      { block_factor = [ 256, 256 ],
        scale_type = i8,
        compute_type = fp8, // storage_type = compute_type if unspecified
        algorithm = { ... } // TODO: Named algorithms? Formulas? Is there a default one?
      }
      : mx_tensor< <4x4xi8>, <1024x1024xfp8> >

  // Dequantize an MX tensor
  %1 = mx_tensor.dequantize %0 : tensor<1024x1024xf32>

Dequantization should be simpler because we already have all the information we need from the mx_tensor type. It can be lowered into a linalg.generic with casts and multiplications.
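
For illustration, the dequantization above could lower to something like the sketch below, written in the same loose notation used later on this page. It assumes %scale and %data are the components of %0 (e.g. obtained via mx_tensor.deconstruct), and that the scale operand is broadcast over each 256x256 block through its indexing map:

  // Sketch: dequantize as a generic with a broadcast scale and a multiply
  %out = tensor.empty() : tensor<1024x1024xf32>
  %deq = linalg.generic ins(%scale, %data) outs(%out)
            maps { broadcast, element-wise, element-wise }
            iterator { parallel, parallel }
       {
         ^bb0(%s: i8, %d: fp8, %o: f32):
            %s1 = cast %s : f32       // cast the scale to the wider type
            %d1 = cast %d : f32       // cast the data to the wider type
            %r  = mul %s1, %d1 : f32  // apply the scaling factor
            yield %r
       }
       : tensor<1024x1024xf32>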

Split/Concat behaviour

  // Constructing an MX tensor from multiple data tensors
  %data = tensor.empty() : tensor<1024x1024xf32>
  %0 = mx_tensor.quantize %data
      { block_factor = [ 256, 256 ],
        scale_type = i8,
        compute_type = fp6,
        storage_types = [ i2, i4 ], // Bit-widths must add up to the compute type's
        algorithm = { ... } // TODO: Named algorithms? Formulas? Is there a default one?
      }
      : mx_tensor< <4x4xi8>, <1024x1024xi2>, <1024x1024xi4> x fp6 >

Use in linear algebra

  // Can be tiled with simple rules
  %a = tensor.empty() : tensor<1024x1024xbf16>
  %b = mx_tensor.empty() : mx_tensor< <4x4xi8>, <1024x1024xfp8> >
  %0 = linalg.add ins(%a, %b) outs(%a) { accumulator_type = f32 } : tensor<1024x1024xbf16>

  // Tiling 32 x 32
  %a = tensor.empty() : tensor<1024x1024xbf16>
  %b = mx_tensor.empty() : mx_tensor< <4x4xi8>, <1024x1024xfp8> >
  scf.parallel (%arg0, %arg1) : (0, 0) to (32, 32) {
    %a0 = tensor.extract_slice %a : tensor<32x32xbf16>
    %b0 = mx_tensor.extract_slice %b : mx_tensor< <1x1xi8>, <32x32xfp8> >     // This will need broadcast, since 4x4 isn't large enough
    %0 = linalg.add ins(%a0, %b0) outs(%a0) { accumulator_type = f32 } : tensor<32x32xbf16>
    tensor.insert_slice %0 into %a : tensor<32x32xbf16>
  }

The broadcast needed when tiling with scaling factors that do not divide evenly across the tiles may be a new op, or expected semantics of the mx_tensor.extract_slice operation. The former is simpler, but can require materializing a larger scaling factor in shared memory. The latter needs a more complex affine map, but avoids duplication.
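
As an illustration of the former, a hypothetical mx_tensor.broadcast_scale op could materialize one scale per tile before slicing (a sketch only; names and syntax are not a proposal):

  // Sketch: materialize one scale per 32x32 tile, then slice without broadcast
  %bc = mx_tensor.broadcast_scale %b
        : mx_tensor< <4x4xi8>, <1024x1024xfp8> > to mx_tensor< <32x32xi8>, <1024x1024xfp8> >
  %b0 = mx_tensor.extract_slice %bc : mx_tensor< <1x1xi8>, <32x32xfp8> > // plain slice, no broadcast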

Lowering to generics

Like type casts, microscaling factors can be lowered as an operation in the generic region's body.

  // MX type
  %a = tensor.empty() : tensor<1024x1024xbf16>
  %b = mx_tensor.empty() : mx_tensor< <4x4xi8>, <1024x1024xfp8> >
  %0 = linalg.add ins(%a, %b) outs(%a) { accumulator_type = f32 } : tensor<1024x1024xbf16>

  // Generic
  %0 = linalg.generic ins(%a, %b) outs(%a)
            maps { element-wise, broadcast, element-wise }
            iterator { parallel, parallel }
       {
         // Note: the number of ins/outs operands is not the same as the number of basic block arguments!
         ^bb0(%arg0: bf16, %arg1: i8, %arg2: fp8):
            // Accumulator cast
            %a1 = cast %arg0 : f32
            // MXFP cast (this is naive `C` lowering)
            %s1 = cast %arg1 : f32
            %b1 = cast %arg2 : f32
            // MXFP scale
            %b2 = mul %s1, %b1 : f32
            // Operation
            %res = add %a1, %b2 : f32
            // Storage cast (this is naive `C` lowering)
            %res1 = cast %res : fp8
            // Return
            yield %res1
      }
      : tensor<1024x1024xbf16>

Note: The casts above are naive and do not work with partial accumulation or contractions. This needs more thought before a full proposal.

There are two ways we can do this:

  1. Change the semantics of linalg operations to cope with mx_tensor, by knowing it consists of two or more inner tensors and validating that the affine maps, iterator types and arguments match.
  2. Always convert to regular tensors as part of the generalization and just lower to a generic with multiple tensors and the appropriate concat/scale logic.

(1) is harder initially, but can lead to simpler tiling: we won't need to teach the TilingInterface to tile the tensor conversions or their relationship with their consumers. (2) is easier initially, but tiling the resulting generics will get cumbersome quickly (see the sketch below).
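
A sketch of what (2) could look like for the add used above: the MX tensor is deconstructed into plain tensors and the scale becomes an extra broadcast operand of the generic, with the same body as the generic shown earlier:

  // Sketch of option (2): deconstruct first, then a generic over plain tensors
  %scale, %data = mx_tensor.deconstruct %b : tensor<4x4xi8>, tensor<1024x1024xfp8>
  %0 = linalg.generic ins(%a, %scale, %data) outs(%a)
            maps { element-wise, broadcast, element-wise, element-wise }
            iterator { parallel, parallel }
       {
         // Here the number of ins/outs operands does match the block arguments
         ^bb0(%arg0: bf16, %arg1: i8, %arg2: fp8, %arg3: bf16):
            // casts, scale multiply, add and storage cast as in the generic above
            ...
            yield %res1
       }
       : tensor<1024x1024xbf16>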