
Generics in MLIR

Renato Golin edited this page Nov 15, 2023 · 4 revisions

Here we explore how to implement generic templates in MLIR, covering type specialization, constant expressions, and compile-time control-flow decisions (based on expressions or type matches).

Extension or Core?

There are two ways of implementing generics in any IR: develop a meta-language / DSL that gets "compiled" into the final (non-generic) IR, or put the semantics in the IR itself. In a way, this is the difference between C macros and C++ templates.

As an external language / extension / DSL, we'd have:

  • Pros:
    • Different parties can develop different templates for their different needs
    • We don't need to change the core IR semantics
    • Implementation is independent from the evolution of MLIR itself
  • Cons:
    • We need to create a whole new set of tools (parsers, lexers, ASTs) that understand MLIR and can fall out of sync with MLIR core
    • Different implementations will need different sets of tools and may conflict with upstream core MLIR semantics
    • Users will need to adopt third-party downstream implementations to get the benefits

As a core MLIR concept, changing the language spec, we'd have:

  • Pros:
    • Co-evolution with core MLIR concepts ensures the generic semantics are always up to date with current developments
    • No need for external, third-party, downstream tools to get the benefits
    • If the semantics are generic enough, downstream projects can still use them differently for their needs
  • Cons:
    • We'd need to change the core IR semantics (or can it be a new dialect?)
    • Generic code will need to be specialized very early, so that other passes don't have to know about generics
    • Alternatively, we fully embrace the new "types" everywhere, which comes at a high cost

From now on, we only consider changing MLIR core, not creating a DSL.

Furthermore, as a core MLIR concept, it can be part of the builtin dialect or a new dialect. The former is more succinct, but far more invasive and potentially destructive to downstream users. The latter is much more verbose and will probably need to duplicate a lot of other dialects' functionality to avoid verification pitfalls.

Motivation

The main motivation for this work is to be able to write kernels directly in MLIR, in a way that is more generic than fat libraries.

Some libraries need to be recompiled when target decisions change, others can JIT-compile based on target information (or cpuid), and some compilers can generate multiple variations of specific kernels and build a run-time dispatch table populated with profile information. All of these scenarios need to be covered by this work, though not necessarily from day one.

The first big step is to be able to write a set of micro-kernels in MLIR that can be used for different element types. Since hardware extensions differ per type, we need both a generic type (e.g. f32 || bf16) and compile-time constant expressions that branch and specialize into type-specific code.

For example, a matmul operation for f32 and bf16 on AVX512_VNNI could be written as:

template < class T >
func.func @mma(%arg0 : tensor<*xT>, %arg1 : tensor<*xT>) -> tensor<*xT> {
  scf.const_if template.is_same(T, f32) {
    %out = tensor.empty() : tensor<*xf32>
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<*xf32>, tensor<*xf32>) outs(%out : tensor<*xf32>) -> tensor<*xf32>
    return %0 : tensor<*xf32>
  } else {
    scf.const_if template.is_same(T, bf16) {
      %out = tensor.empty() : tensor<*xbf16>
      %vnni = tensor.vnni_pack (%arg1) : tensor<*xbf16> // Let's assume there's such a thing
      %1 = linalg.matmul ins(%arg0, %vnni : tensor<*xbf16>, tensor<*xbf16>) outs(%out : tensor<*xbf16>) -> tensor<*xbf16>
      return %1 : tensor<*xbf16>
    } else {
      // No specialization for this element type
      %false = arith.constant false
      cf.assert %false, "Unsupported element type"
    }
  }
}
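Since the proposal borrows C++ template syntax, it may help to see the intended semantics of `scf.const_if` and `is_same` in C++ itself, where `if constexpr` with `std::is_same_v` discards non-matching branches at instantiation time. This is only an analogy, not a proposed implementation: `bf16` is a hypothetical stand-in struct wrapping a `float`, matrices are row-major `std::vector`s, and the VNNI packing step is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for bf16: a distinct type tag wrapping a float.
// (A real kernel would use a hardware/library bf16 type with 16-bit storage.)
struct bf16 {
  float value;
};

// Dependent "false" so the static_assert only fires if that branch is instantiated.
template <class>
inline constexpr bool always_false = false;

// Row-major C[m x n] = A[m x k] * B[k x n].
// `if constexpr` + std::is_same_v play the role of `scf.const_if is_same(T, ...)`:
// only the branch matching T survives instantiation.
template <class T>
std::vector<T> mma(const std::vector<T>& a, const std::vector<T>& b,
                   std::size_t m, std::size_t k, std::size_t n) {
  assert(a.size() == m * k && b.size() == k * n);
  std::vector<T> c(m * n);
  if constexpr (std::is_same_v<T, float>) {
    // f32 path: multiply and accumulate directly in f32.
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (std::size_t p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
        c[i * n + j] = acc;
      }
  } else if constexpr (std::is_same_v<T, bf16>) {
    // bf16 path: a real kernel would pack B (VNNI layout) first; this sketch
    // only widens to f32 for accumulation and stores the result back.
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (std::size_t p = 0; p < k; ++p)
          acc += a[i * k + p].value * b[p * n + j].value;
        c[i * n + j] = bf16{acc};
      }
  } else {
    // Counterpart of the "Unsupported element type" branch, rejected at compile time.
    static_assert(always_false<T>, "Unsupported element type");
  }
  return c;
}
```

Calling `mma` with `std::vector<float>` compiles only the f32 branch; calling it with, say, `std::vector<int>` fails to compile. One design difference: the MLIR sketch reports unsupported types through a run-time `cf.assert`, whereas `static_assert` rejects them during instantiation.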

This micro-kernel would be called from code that could have static shapes:

  ...
  %arg0 = tensor.cast %a : tensor<32x64xf32> to tensor<*xf32>
  %arg1 = tensor.cast %b : tensor<64x128xf32> to tensor<*xf32>
  %ret = call @mma(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> // Unresolved: @mma is only a template, there is no concrete function yet!
  %out = tensor.cast %ret : tensor<*xf32> to tensor<32x128xf32>
  ...

The specialization pass would create a version for f32 (resolving the const_if expression) and redirect the call to it:

  ...
  %arg0 = tensor.cast %a : tensor<32x64xf32> to tensor<*xf32>
  %arg1 = tensor.cast %b : tensor<64x128xf32> to tensor<*xf32>
  %ret = call @mma_f32(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> // Resolved: @mma_f32 was just created
  %out = tensor.cast %ret : tensor<*xf32> to tensor<32x128xf32>
  ...

After inlining, shape propagation would replace every tensor<*xf32> with its respective static type, and you'd have:

  ...
  %buf = tensor.empty() : tensor<32x128xf32>
  %out = linalg.matmul ins(%a, %b : tensor<32x64xf32>, tensor<64x128xf32>) outs(%buf : tensor<32x128xf32>) -> tensor<32x128xf32>
  ...
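The specialization step above can be modelled, very roughly, as: resolve the const_if chain for the concrete element type, emit a mangled copy such as `mma_f32` only once per type, and point the call at that symbol. The C++ below is a toy model under those assumptions; `GenericFunc`, `Module` and `specialize` are hypothetical names (not MLIR APIs), and function bodies are plain strings rather than IR.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model: a generic function is a name plus the branches of
// its const_if chain, each pairing the element type it matches with a body.
struct GenericFunc {
  std::string name;
  std::vector<std::pair<std::string, std::string>> branches;  // (type, body)
};

// Concrete functions emitted so far, keyed by mangled symbol name.
struct Module {
  std::map<std::string, std::string> funcs;
};

// Resolves the const_if chain for `type`, emits `<name>_<type>` at most once
// (later calls reuse it) and returns the symbol the call should now target.
// Returns std::nullopt when no branch matches (unsupported element type).
std::optional<std::string> specialize(Module& mod, const GenericFunc& fn,
                                      const std::string& type) {
  assert(!type.empty() && "element type must be concrete");
  const std::string mangled = fn.name + "_" + type;
  if (mod.funcs.count(mangled) != 0) return mangled;  // already instantiated
  for (const auto& [match, body] : fn.branches) {
    if (match == type) {              // compile-time is_same(T, match)
      mod.funcs[mangled] = body;      // only the taken branch is kept
      return mangled;
    }
  }
  return std::nullopt;
}
```

Caching by mangled name lets every call site with the same element type share one instance, much like C++ template instantiation.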

As Core MLIR

Generic Types

template < class T >

Generic Shapes

template < shape S >

Generic Values

template < bool hasBF16Support >

ConstExpr

scf.const_if
scf.const_for ?
scf.const_switch ?

template.is_same (type, shape, value)
template.type_switch / shape_switch / value_switch ?
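The parameter kinds and const-expression ops listed above have close C++ counterparts: type parameters, non-type (value) parameters, and `if constexpr`. As a hypothetical illustration, the sketch below models `shape S` as two `std::size_t` parameters (C++ has no shape kind of its own) and uses `if constexpr` chains where the MLIR sketch lists `scf.const_if` and `shape_switch`. The function names are made up for the example.

```cpp
#include <cstddef>
#include <string_view>

// `template < shape S >`, modelled (hypothetically) as two non-type
// template parameters, so the shape is a compile-time constant.
template <std::size_t Rows, std::size_t Cols>
constexpr std::size_t num_elements() {
  return Rows * Cols;
}

// `template < bool hasBF16Support >`: a compile-time value parameter.
// `if constexpr` stands in for `scf.const_if` on a constant expression.
template <bool hasBF16Support>
constexpr std::string_view pick_kernel() {
  if constexpr (hasBF16Support) {
    return "native bf16 kernel";
  } else {
    return "bf16 emulated via f32";
  }
}

// In the spirit of a shape switch: resolved at compile time because
// Rows and Cols are template parameters.
template <std::size_t Rows, std::size_t Cols>
constexpr std::string_view shape_case() {
  if constexpr (Rows == 64 && Cols == 64) {
    return "64x64: use as-is";
  } else if constexpr (Rows == 32 && Cols == 64) {
    return "32x64: expand to 64x64";
  } else {
    return "other shape";
  }
}

// Everything above is evaluated by the compiler, not at run time.
static_assert(num_elements<32, 64>() == 2048);
static_assert(pick_kernel<true>() == "native bf16 kernel");
```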

As a dialect

We could create a set of types that pass verification while still uncertain (not yet concrete), but must be converted to native or other dialect types before lowering. This composes well with existing dialects, but will lead to verbose and perhaps confusing notation. It could, however, serve as an entry point for the idea.

template.any_type
template.any_type_of<f32, bf16> ?

template.any_shape
template.any_shape_of<32x64, 32x32, 64x64> ?

template.constant true : bool

template.const_if
...

We'd probably also need:
template.func
template.call

The types can be used instead of native/dialect ones:

// Note: tensor<64x64x T> verification will need to be bypassed or we'll need template.tensor, etc.
template.func
    { S : template.any_shape_of<32x64, 32x32, 64x64>, T : template.any_type_of<f32, bf16> }
    @expand(%arg0 : tensor< S x T> ) -> tensor<64x64xT> {
  template.const_shape_switch
    case 64x64 {
      return %arg0
    }
    case 32x64 {
      %0 = tensor.empty() : tensor<64x64xT>
      %1 = tensor.insert_slice %arg0 into %0[0, 0][32, 64][1, 1] : ...
      %2 = tensor.insert_slice %arg0 into %1[32, 0][32, 64][1, 1] : ...
      return %2
    }
    ...
}

func.func @caller(...) {
  %0 = tensor.empty() : tensor<32x32xf32>
  %1 = template.call @expand(%0) -> tensor<64x64xf32> // OK

  %2 = tensor.empty() : tensor<32x128xf32>
  %3 = template.call @expand(%2) -> tensor<64x64xf32> // ERROR: Invalid shape

  %4 = tensor.empty() : tensor<32x32xf16>
  %5 = template.call @expand(%4) -> tensor<64x64xf32> // ERROR: Invalid type
}
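The verification behaviour in the `@caller` example (accept a 32x32 f32 argument, reject a 32x128 shape or an f16 element type) amounts to a membership check against the `any_shape_of` and `any_type_of` lists. A minimal C++ sketch of that check follows, assuming element types are represented as strings; `Shape`, `CallCheck` and `check_expand_call` are hypothetical names, and `constexpr` lets the check run at compile time, as a verifier would.

```cpp
#include <array>
#include <cstddef>
#include <string_view>

// Hypothetical model of the constraints on @expand:
// S : any_shape_of<32x64, 32x32, 64x64>, T : any_type_of<f32, bf16>.
struct Shape {
  std::size_t rows;
  std::size_t cols;
};

enum class CallCheck { Ok, InvalidShape, InvalidType };

constexpr std::array<Shape, 3> kAllowedShapes{{{32, 64}, {32, 32}, {64, 64}}};
constexpr std::array<std::string_view, 2> kAllowedTypes{"f32", "bf16"};

// Returns the first constraint the argument violates, mirroring the
// OK / ERROR annotations on the template.call operations.
constexpr CallCheck check_expand_call(Shape s, std::string_view elemType) {
  bool shapeOk = false;
  for (const Shape& allowed : kAllowedShapes)
    if (allowed.rows == s.rows && allowed.cols == s.cols) shapeOk = true;
  if (!shapeOk) return CallCheck::InvalidShape;
  for (std::string_view t : kAllowedTypes)
    if (t == elemType) return CallCheck::Ok;
  return CallCheck::InvalidType;
}

// Evaluated by the compiler, like a verifier running before lowering.
static_assert(check_expand_call({32, 32}, "f32") == CallCheck::Ok);
```

In C++20 the same constraints could instead be attached to a template through a `requires` clause, so an invalid call fails to compile, which is closer to how `template.call` verification would surface these errors.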