
Pass Analysis


This document runs a simple MLP kernel through the compiler and shows the main passes, what they do, and what they are expected to do.

Input IR

We're looking at a simple 2-layer GEMM+ADD+RELU kernel of 256x128 -> 256x256 -> 256x512.

From the build directory, call: $ mlir-gen --bias --relu

This generates:

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @entry(%arg0: tensor<256x128xf32>) -> tensor<256x512xf32> {

    // WEIGHT + BIAS 1
    %cst = arith.constant dense<1.000000e+00> : tensor<128x256xf32>
    %cst_0 = arith.constant dense<1.000000e+00> : tensor<256x256xf32>

    // BUFFER 1
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<256x256xf32>
    %1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>

    // GEMM 1
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<256x128xf32>, tensor<128x256xf32>) outs(%1 : tensor<256x256xf32>) {
    ^bb0(%in: f32, %in_7: f32, %out: f32):
      %10 = arith.mulf %in, %in_7 : f32
      %11 = arith.addf %out, %10 : f32
      linalg.yield %11 : f32
    } -> tensor<256x256xf32>

    // BIAS ADD 1
    %3 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%cst_0 : tensor<256x256xf32>) outs(%2 : tensor<256x256xf32>) {
    ^bb0(%in: f32, %out: f32):
      %10 = arith.addf %in, %out : f32
      linalg.yield %10 : f32
    } -> tensor<256x256xf32>

    // RELU 1
    %cst_2 = arith.constant 0.000000e+00 : f32
    %4 = linalg.generic {indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} outs(%3 : tensor<256x256xf32>) {
    ^bb0(%out: f32):
      %10 = arith.maximumf %out, %cst_2 : f32
      linalg.yield %10 : f32
    } -> tensor<256x256xf32>

    // WEIGHT + BIAS 2
    %cst_3 = arith.constant dense<1.000000e+00> : tensor<256x512xf32>
    %cst_4 = arith.constant dense<1.000000e+00> : tensor<256x512xf32>

    // BUFFER 2
    %cst_5 = arith.constant 0.000000e+00 : f32
    %5 = tensor.empty() : tensor<256x512xf32>
    %6 = linalg.fill ins(%cst_5 : f32) outs(%5 : tensor<256x512xf32>) -> tensor<256x512xf32>

    // GEMM 2
    %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%4, %cst_3 : tensor<256x256xf32>, tensor<256x512xf32>) outs(%6 : tensor<256x512xf32>) {
    ^bb0(%in: f32, %in_7: f32, %out: f32):
      %10 = arith.mulf %in, %in_7 : f32
      %11 = arith.addf %out, %10 : f32
      linalg.yield %11 : f32
    } -> tensor<256x512xf32>

    // BIAS ADD 2
    %8 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%cst_4 : tensor<256x512xf32>) outs(%7 : tensor<256x512xf32>) {
    ^bb0(%in: f32, %out: f32):
      %10 = arith.addf %in, %out : f32
      linalg.yield %10 : f32
    } -> tensor<256x512xf32>

    // RELU 2
    %cst_6 = arith.constant 0.000000e+00 : f32
    %9 = linalg.generic {indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} outs(%8 : tensor<256x512xf32>) {
    ^bb0(%out: f32):
      %10 = arith.maximumf %out, %cst_6 : f32
      linalg.yield %10 : f32
    } -> tensor<256x512xf32>

    // RETURN
    return %9 : tensor<256x512xf32>
  }
}

How to Run

To look at the compiler passes and their output, just run: $ ./scripts/debug/debug_all_passes.sh

This will generate the IR with mlir-gen, pipe it to tpp-opt through the default TPP passes, dump the IR after each pass, split the result into multiple files and compare them whenever the IR changes.

The default diff program is diff, but you can choose another one (e.g. vimdiff or meld). You can also pass the BIN directory where the tools are; by default, the script assumes it's build/bin under the git root.

$ ./scripts/debug/debug_all_passes.sh -d meld -b /tmp/tpp-mlir-build

The script will dump the output and split the files into a temporary directory created with mktemp, and will print that directory to stdout to help you look at the files directly.
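
If you want to reproduce what the script does by hand, it is roughly equivalent to piping the generator into tpp-opt with per-pass IR printing enabled. This is only a sketch, not the script's exact command line: it assumes the default pipeline is exposed as --default-tpp-passes and uses MLIR's generic --mlir-print-ir-after-all flag, whose per-pass dumps go to stderr.

$ mlir-gen --bias --relu | tpp-opt --default-tpp-passes --mlir-print-ir-after-all 2> all-passes.mlir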

Analysis

Input

The models use a single input and return a single output; all weights and biases are constants (inference). Each GEMM creates a new buffer (the C matrix) and fills it with zeroes. The MLP model reuses the same output buffer to do the bias add and ReLU in-place. The buffer allocated by the final GEMM is returned as the output.

Pack Matmul

As expected, pack-matmul generates 3 packs (input, weight, buffer) and one unpack (GEMM result) for each layer. Note that the packing transposes the blocks, so both matrices (A and B) are traversed horizontally for maximum pre-fetching.

  %2 = tensor.empty() : tensor<8x4x32x32xf32>
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>
  %3 = tensor.empty() : tensor<8x4x32x32xf32>
  %pack_3 = tensor.pack %cst_0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<128x256xf32> -> tensor<8x4x32x32xf32>
  %4 = tensor.empty() : tensor<8x8x32x32xf32>
  %pack_4 = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %4 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
  %5 = linalg.generic ... ins(%pack, %pack_3 : tensor<8x4x32x32xf32>, tensor<8x4x32x32xf32>) outs(%pack_4 : tensor<8x8x32x32xf32>)
  %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>

  // Add, Relu on original shapes 256x256 ...

  // Original shape tensor.empty for the next layer

Propagate Pack & Unpack

This pass propagates the packs/unpacks by replicating on the element-wise ops the same packing that was done on the GEMMs above. The result is a proliferation of tensor.pack and tensor.unpack operations around every linalg op. However, because each unpack is followed by an identical pack, this gets simplified later.

  // Input, weight and accumulation packing
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>
  %pack_3 = tensor.pack %cst_0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<128x256xf32> -> tensor<8x4x32x32xf32>
  %pack_4 = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %4 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>

  // GEMM
  %5 = linalg.generic { ... } -> tensor<8x8x32x32xf32>

  // Unpack
  %unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>

  // GEMM's output re-pack & bias pack
  %pack_5 = tensor.pack %unpack inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %6 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
  %pack_6 = tensor.pack %cst_1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %7 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>

  // BIAS ADD
  %8 = linalg.generic { ... } -> tensor<8x8x32x32xf32>

  // Unpack
  %unpack_7 = tensor.unpack %8 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>

  // ADD's output re-pack
  %pack_8 = tensor.pack %unpack_7 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %9 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>

  // RELU
  %10 = linalg.generic { ... } -> tensor<8x8x32x32xf32>

  // NOTE: RELU's output is not unpacked, and the input pack to the next layer is removed
  // This works fine and may be a side-effect of the packing (not add the last one)
  // But it can also be a BUG that just happens to work and isn't producing correct code

Constant Fold Pack

As expected, because the constants are splat, the "constant folding" is just a "constant reshape", so we get:

  // Original
  %cst = arith.constant dense<1.000000e+00> : tensor<256x512xf32>
  ...
  %3 = tensor.empty() : tensor<8x4x32x32xf32>
  %pack_3 = tensor.pack %cst_0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<128x256xf32> -> tensor<8x4x32x32xf32>

  // Reshaped
  %cst = arith.constant dense<1.000000e+00> : tensor<16x8x32x32xf32>
  ...
  // NOTE: The old packed buffer continues to exist and could be removed by the pass
  // But following cleanups will remove it anyway
  %3 = tensor.empty() : tensor<8x4x32x32xf32>

The buffer allocation is also "constant folded": even though it isn't a constant per se, it is initialized with zeroes, so it can be "constant reshaped":

  // From
  %0 = tensor.empty() : tensor<256x256xf32>
  %1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
  ...
  %4 = tensor.empty() : tensor<8x8x32x32xf32>
  %pack_4 = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %4 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>

  // To
  // NOTE: The original unpacked tensor continues to exist, as the output of the GEMM is still unpacked into this buffer
  // This will be simplified on the next pass (since we'll remove the `unpack`/`pack` pair)
  %0 = tensor.empty() : tensor<256x256xf32>
  %1 = linalg.fill ins(%cst_6 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
  ...
  %4 = tensor.empty() : tensor<8x8x32x32xf32>
  %5 = linalg.fill ins(%cst_6 : f32) outs(%4 : tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32>

Simplify Pack

In the MLP model, it propagates the packed shape through the following element-wise ops. On the GEMM-only kernel, it fuses the unpack with the next layer's pack directly.

  // GEMM on packed shapes
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst_1 : tensor<8x4x32x32xf32>, tensor<8x4x32x32xf32>) outs(%2 : tensor<8x8x32x32xf32>) {
  ^bb0(%in: f32, %in_4: f32, %out: f32):
    %13 = arith.mulf %in, %in_4 : f32
    %14 = arith.addf %out, %13 : f32
    linalg.yield %14 : f32
  } -> tensor<8x8x32x32xf32>
  // No more tensor.unpack

  // Bias Add on the packed shape
  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_2 : tensor<8x8x32x32xf32>) outs(%3 : tensor<8x8x32x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %13 = arith.addf %in, %out : f32
    linalg.yield %13 : f32
  } -> tensor<8x8x32x32xf32>

  // ReLU on the packed shape
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%4 : tensor<8x8x32x32xf32>) {
  ^bb0(%out: f32):
    %13 = arith.maximumf %out, %cst_3 : f32
    linalg.yield %13 : f32
  } -> tensor<8x8x32x32xf32>

  // Next GEMM takes the packed shape %5 directly
  %10 = linalg.generic ... ins(%5, %cst : tensor<8x8x32x32xf32>, tensor<16x8x32x32xf32>) outs(%9 : tensor<8x16x32x32xf32>)

At the end of this pass, at least for MLP-like models, the kernel code should only contain one pack (for the input, at the beginning) and one unpack (for the output, at the end, before return).

All of the previous unpacked constant tensors and empty buffers should have been removed too.
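
As a rough illustration of that end state, the function boundary should look like the sketch below. This is only an outline of the expected shape of the IR, with placeholder value names (%res, %out), not a literal dump:

  // Single pack of the input at the entry
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>

  // ... both layers (GEMM + bias + ReLU) computed entirely on packed shapes ...

  // Single unpack of the result before the return
  %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %out : tensor<8x16x32x32xf32> -> tensor<256x512xf32>
  return %unpack : tensor<256x512xf32>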

Tile And Fuse

This pass creates scf.forall loops and adds the tiled linalg ops inside. Since we're fusing GEMMs with element-wise ops in the MLP, the "tile" operation is a BRGEMM, not a GEMM, so the final op becomes linalg.batch_reduce_matmul instead of just linalg.matmul.

The final pattern is:

  // NOTE: Because we're using splat constant, the tiling just reduces the dimensionality of the tensor
  // On real benchmarks we use random tensors, so this becomes an `extract_slice` like the input
  %cst = arith.constant dense<1.000000e+00> : tensor<4x32x32xf32>

  // C matrix, note the `zero` was inlined in the loop
  // BUG: The second layer is not "moved" inside, there's still a zero outside and one inside
  %1 = tensor.empty() : tensor<8x8x32x32xf32>

  // Parallel loop
  %2 = scf.forall (%arg1, %arg2) in (8, 8) shared_outs(%arg3 = %1) -> (tensor<8x8x32x32xf32>) {
    // C matrix tile + zero
    %extracted_slice = tensor.extract_slice %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x8x32x32xf32> to tensor<32x32xf32>
    %7 = linalg.fill ins(%cst_1 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>

    // Input tile
    %extracted_slice_2 = tensor.extract_slice %pack[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : tensor<8x4x32x32xf32> to tensor<4x32x32xf32>

    // BRGEMM on block line x column
    %8 = linalg.batch_reduce_matmul ins(%extracted_slice_2, %cst : tensor<4x32x32xf32>, tensor<4x32x32xf32>) outs(%7 : tensor<32x32xf32>) -> tensor<32x32xf32>

    // Bias add and ReLU on 2D tile
    %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<32x32xf32>) outs(%8 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = arith.addf %in, %out : f32
      linalg.yield %11 : f32
    } -> tensor<32x32xf32>
    %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%9 : tensor<32x32xf32>) {
    ^bb0(%out: f32):
      %11 = arith.maximumf %out, %cst_2 : f32
      linalg.yield %11 : f32
    } -> tensor<32x32xf32>

    // Write back into C matrix (async)
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %10 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x8x32x32xf32>
    }
  }

The zero was moved inside the loop so that we can later match zero + gemm into gemm_with_beta_zero (a GEMM with beta = 0). The second layer is handled in the same way.
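
For illustration, once we reach the XSMM level (see the sections further down, whose operand names are reused here), this match would let the explicit zero on the C tile be folded into the BRGEMM dispatch itself. The sketch below is hypothetical; in particular, the beta_0 flag spelling is an assumption and is not taken from this dump:

  // Instead of an `xsmm.unary zero` on the C tile followed by a plain BRGEMM,
  // a single BRGEMM dispatched with beta = 0 overwrites the C tile directly
  %d = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0) data_type = f32
  xsmm.brgemm(data_type = f32, %d, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()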

Simplify Pack (again)

This second pass will "consume" the last unpack into the scf.forall, so that we write directly to the right tile instead of needing to go through the entire tensor again. The first layer has no unpack and its consumer (the second layer) is also packed, so nothing happens to it.

  // Input packing & packed buffer allocation
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>
  %1 = tensor.empty() : tensor<8x8x32x32xf32>

  // Layer 1 (note %arg3 is still packed)
  %2 = scf.forall (%arg1, %arg2) in (8, 8) shared_outs(%arg3 = %1) -> (tensor<8x8x32x32xf32>) {
    // BRGEMM + ADD + RELU in a packed shape

    // Parallel write into packed shape
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %9 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x8x32x32xf32>
    }
  }

  // Final return (unpacked) BUG: Note the unnecessary whole-tensor zero
  %3 = tensor.empty() : tensor<256x512xf32>
  %4 = linalg.fill ins(%cst_2 : f32) outs(%3 : tensor<256x512xf32>) -> tensor<256x512xf32>

  // Layer 2 (note %arg3 is now unpacked)
  %5 = scf.forall (%arg1, %arg2) in (8, 16) shared_outs(%arg3 = %4) -> (tensor<256x512xf32>) {
    // BRGEMM + ADD + RELU in a packed shape

    // Parallel write into unpacked shape (note the equivalent dimensionality)
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %11 into %arg3[%6, %7] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<256x512xf32>
    }
  }
  return %5 : tensor<256x512xf32>

Lower Packs & Unpacks

This basically lowers tensor.pack and tensor.unpack into loops. We detect block transposes and lower them to tensor.extract_slice + tensor.parallel_insert_slice to speed up block copies.

  // New buffer needs to be allocated
  %0 = tensor.empty() : tensor<8x4x32x32xf32>

  // Parallel loop
  %1 = scf.forall (%arg1, %arg2) in (8, 4) shared_outs(%arg3 = %0) -> (tensor<8x4x32x32xf32>) {
    // Block of 32x32
    %7 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg1)
    %8 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg2)

    // Extract the 32x32 slice from %arg0[%arg1, %arg2]
    %extracted_slice = tensor.extract_slice %arg0[%7, %8] [32, 32] [1, 1] : tensor<256x128xf32> to tensor<32x32xf32>

    // Insert 32x32 slice into %arg3. NOTE: Shouldn't this be %arg3[%arg2, %arg1]?
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %extracted_slice into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x4x32x32xf32>
    }
  }

This pattern will lower to memref.copy, which is more efficient than single-element copies.

One-shot-bufferize

This pass (and the subsequent bufferization cleanups) converts tensors to memrefs and makes sure all required buffers are allocated. The objective here is to allocate as few buffers as possible and to reuse as many as possible.

Currently, bufferization only reuses buffers within the same chain (in-place computation), but it would also be beneficial to reuse a buffer when its previous use is dead and the new shape is the same as the original buffer's.

Note that this only adds memref.alloc, not memref.dealloc; deallocations will be added by later passes.

Constant tensors become global memrefs. Constant scalars remain as arith.constant.

This is done as a pair of memref.global and memref.get_global:

  memref.global "private" constant @__constant_32x32xf32 : memref<32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
  memref.global "private" constant @__constant_8x32x32xf32 : memref<8x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
  memref.global "private" constant @__constant_4x32x32xf32 : memref<4x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
  func.func @entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
    %0 = memref.get_global @__constant_4x32x32xf32 : memref<4x32x32xf32>
    %1 = memref.get_global @__constant_8x32x32xf32 : memref<8x32x32xf32>
    %2 = memref.get_global @__constant_32x32xf32 : memref<32x32xf32>

tensor.empty gets converted into memref.alloc

Only when it can't be reused:

    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>

scf.forall in "memory semantics" (no shared_outs)

    scf.forall (%arg1, %arg2) in (8, 4) {
      %3 = affine.apply #map(%arg1)
      %4 = affine.apply #map(%arg2)
      ...
      // No more `tensor.parallel_insert_slice`, writes are done directly to the buffers
    }

tensor.extract_slice converts to memref.subview

    scf.forall (%arg1, %arg2) in (8, 4) {
      ...

      // This is the `tensor.pack` converted into tile copies above, now very clear
      %subview = memref.subview %arg0[%3, %4] [32, 32] [1, 1] : memref<256x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
      %subview_2 = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
      memref.copy %subview, %subview_2 : memref<32x32xf32, strided<[128, 1], offset: ?>> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    }

Linalg ops converted to memref

Now, the outs really means "write to that buffer":

    scf.forall (%arg1, %arg2) in (8, 8) {
      ...
      linalg.batch_reduce_matmul ins(%subview_2, %0 : memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)

      linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%2 : memref<32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
      ^bb0(%in: f32, %out: f32):
        %3 = arith.addf %in, %out : f32
        linalg.yield %3 : f32
      }

      linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
      ^bb0(%out: f32):
        %3 = arith.maximumf %out, %cst : f32
        linalg.yield %3 : f32
      }
      ...
    }

However, there's a duplication at the end of the loop that will be cleaned up by later passes:

    // New buffer (C)
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
    // Layer loop
    scf.forall (%arg1, %arg2) in (8, 8) {
      // C tile
      %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
      // C = <0>
      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
      // Input
      %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
      // Compute
      linalg.batch_reduce_matmul
      linalg.generic { arith.addf }
      linalg.generic { arith.maximumf }

      // NOTE: Redundant copy, bufferization is not smart enough, needs cleanup later
      %subview_3 = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
      memref.copy %subview, %subview_3 : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<32x32xf32, strided<[32, 1], offset: ?>>
    }

Ownership based deallocation

This is the pass that inserts deallocations (as bufferization.dealloc ops) based on ownership. In our example (2 layers), there are three main allocations (%alloc, %alloc_0 and %alloc_1). The first (%alloc) holds the tensor.pack result, is used by the first layer, and is dead by the end of that layer. The second (%alloc_0) is used as input to the second layer and is dead by the end of it, so it has to be deallocated. The third (%alloc_1) is returned by the function, so it's still alive after exit and cannot be deallocated.

  func.func @entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
    ...

    // Layer 1
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
    scf.forall (%arg1, %arg2) in (8, 8) {
      ...
    }

    // Layer 2
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xf32>
    // Note: redundant fill is still here
    linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<256x512xf32>)
    scf.forall (%arg1, %arg2) in (8, 16) {
      ...
    }

    // Return, retaining %alloc_1
    %3 = bufferization.dealloc (%alloc, %alloc_0, %alloc_1 : memref<8x4x32x32xf32>, memref<8x8x32x32xf32>, memref<256x512xf32>) if (%true, %true, %true) retain (%alloc_1 : memref<256x512xf32>)
    return %alloc_1 : memref<256x512xf32>
  }

Buffer deallocation simplification

Converts bufferization.dealloc into the appropriate memref.dealloc:

    ...
    memref.dealloc %alloc : memref<8x4x32x32xf32>
    memref.dealloc %alloc_0 : memref<8x8x32x32xf32>
    return %alloc_1 : memref<256x512xf32>
  }
}

NOTE: The first allocation could have been freed by the end of the first layer, but since it goes through bufferization.dealloc, all deallocations happen at the end of the model. This is fine from a memory safety point of view, but larger models could unnecessarily run out of memory.
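
For comparison, an eager scheme could free the first buffer as soon as the first layer finishes. A sketch (buffer names as above) of where such a memref.dealloc could be placed:

    // Layer 1
    scf.forall (%arg1, %arg2) in (8, 8) {
      ...
    }

    // %alloc (the packed input) is dead at this point and could be freed eagerly
    memref.dealloc %alloc : memref<8x4x32x32xf32>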

Memref to XSMM

This pass converts memref operations into XSMM calls. For now, this is only for the packing, where we use memref.copy, which gets converted to an XSMM identity (copy) call.

  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
  scf.forall (%arg1, %arg2) in (8, 4) {
    %3 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg1)
    %4 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg2)
    %subview = memref.subview %arg0[%3, %4] [32, 32] [1, 1] : memref<256x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
    %subview_2 = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>

    // NOTE: The dispatch is the function that compiles the kernel and only needs to be run once.
    // It will be hoisted and commoned up on a later pass
    %5 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32

    // This is identical to `memref.copy`
    xsmm.unary identity(data_type = f32, %5, %subview, %subview_2) : (i64, memref<32x32xf32, strided<[128, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
  }

Linalg to XSMM

This is the core transform that converts generic Linalg operations into XSMM calls. The previous pass only converts the packs to XSMM; this one converts everything else.

  // Buffer 1
  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>

  // Layer 1
  scf.forall (%arg1, %arg2) in (8, 8) {

    // Tiles
    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>

    // Fill with zero -> XSMM ZERO
    // NOTE: This can be converted to BRGEMM(BETA=0) later
    %4 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
    xsmm.unary zero(data_type = f32, %4, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>

    // BRGEMM
    // NOTE: This can be fused into FUSED_BRGEMM(BIN=ADD, UN=RELU)
    %5 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
    xsmm.brgemm(data_type = f32, %5, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()

    // BIAS ADD
    %6 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.binary add(data_type = f32, %6, %2, %subview, %subview) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()

    // RELU
    %7 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
    xsmm.unary relu(data_type = f32, %7, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
  }

NOTE: All dispatches are still inside the inner loop and will be hoisted and commoned up later.

Convert scf.forall to scf.parallel

In our usage, scf.forall and scf.parallel mean the same thing, but the OpenMP pass only detects the latter, so we need to convert to get OpenMP functionality.

  // From
  scf.forall (%arg1, %arg2) in (8, 8) {
    ...
  }

  // To
  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
    ...
    scf.yield
  }

Loop Invariant Code Motion (LICM)

This is the pass that hoists all loop-invariant calls (our dispatches):

  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>

  // All dispatches outside of the loop now
  %4 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
  %5 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
  %6 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
  %7 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32

  // Layer 1 loop
  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
    // C tile
    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>

    // C = <0>
    xsmm.unary zero(data_type = f32, %4, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()

    // Input tile
    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>

    // BRGEMM: only execution here
    xsmm.brgemm(data_type = f32, %5, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
    xsmm.binary add(data_type = f32, %6, %2, %subview, %subview) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    xsmm.unary relu(data_type = f32, %7, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    scf.yield
  }

XSMM Fusion

This (future #752) pass looks at a sequence of XSMM operations and tries to fuse them according to the target's rules. LIBXSMM supports fused BRGEMMs, where a single function call handles the GEMM-like operation plus a (configurable) binary and unary operation, and has flags such as BETA=0 that allow us to elide the initial xsmm.unary zero.

NOTE: The initial zero inside the loop isn't costly (it operates on the tile, so it's cache-friendly), and the same can be said for the following element-wise operations. However, this is only the case for simple patterns. For example, if the element-wise ops don't bufferize correctly and add allocs in between, fusing will remove that complexity and return performance to the baseline.

TODO: Once the PR is merged, add IR for this pass.

Convert XSMM to func

Once all XSMM ops are lowered and fused, we convert them to function calls. These functions are available in our runtime library, which marshals arguments and return values and calls LIBXSMM functions directly.

  // From
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32
  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c4) step (%c1, %c1) {
    ...
    xsmm.unary identity(data_type = f32, %3, %subview, %subview_2) : (i64, memref<32x32xf32, strided<[128, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
    scf.yield
  }

  // To
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
  %3 = call @xsmm_unary_dispatch(%c1_i64, %c1_i64, %c32_i64, %c32_i64, %c128_i64, %c32_i64, %c0_i64) : (i64, i64, i64, i64, i64, i64, i64) -> i64
  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c4) step (%c1, %c1) {
    ...

    // NOTE: This marshalling is generated by the compiler to extract the MLIR memref semantics
    // Setting up LIBXSMM's arguments and its flags is done in the runtime shim layer
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : memref<32x32xf32, strided<[128, 1], offset: ?>> -> memref<f32>, index, index, index, index, index
    %intptr_3 = memref.extract_aligned_pointer_as_index %subview : memref<32x32xf32, strided<[128, 1], offset: ?>> -> index
    %17 = arith.index_cast %intptr_3 : index to i64
    %18 = llvm.inttoptr %17 : i64 to !llvm.ptr<f32>
    %base_buffer_4, %offset_5, %sizes_6:2, %strides_7:2 = memref.extract_strided_metadata %subview_2 : memref<32x32xf32, strided<[32, 1], offset: ?>> -> memref<f32>, index, index, index, index, index
    %intptr_8 = memref.extract_aligned_pointer_as_index %subview_2 : memref<32x32xf32, strided<[32, 1], offset: ?>> -> index
    %19 = arith.index_cast %intptr_8 : index to i64
    %20 = llvm.inttoptr %19 : i64 to !llvm.ptr<f32>

    // Invoke call
    func.call @xsmm_unary_invoke(%c1_i64, %3, %18, %offset, %20, %offset_5) : (i64, i64, !llvm.ptr<f32>, index, !llvm.ptr<f32>, index) -> ()
    scf.yield
  }