Pass Analysis
This document runs a simple MLP kernel through the compiler and shows the main passes, what they do, and what they are expected to do.
We're looking at a simple 2-layer GEMM+ADD+RELU kernel of 256x128 -> 256x256 -> 256x512.
From the build directory, call:
$ mlir-gen --bias --relu
This generates:
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @entry(%arg0: tensor<256x128xf32>) -> tensor<256x512xf32> {
// WEIGHT + BIAS 1
%cst = arith.constant dense<1.000000e+00> : tensor<128x256xf32>
%cst_0 = arith.constant dense<1.000000e+00> : tensor<256x256xf32>
// BUFFER 1
%cst_1 = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<256x256xf32>
%1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
// GEMM 1
%2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<256x128xf32>, tensor<128x256xf32>) outs(%1 : tensor<256x256xf32>) {
^bb0(%in: f32, %in_7: f32, %out: f32):
%10 = arith.mulf %in, %in_7 : f32
%11 = arith.addf %out, %10 : f32
linalg.yield %11 : f32
} -> tensor<256x256xf32>
// BIAS ADD 1
%3 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%cst_0 : tensor<256x256xf32>) outs(%2 : tensor<256x256xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<256x256xf32>
// RELU 1
%cst_2 = arith.constant 0.000000e+00 : f32
%4 = linalg.generic {indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} outs(%3 : tensor<256x256xf32>) {
^bb0(%out: f32):
%10 = arith.maximumf %out, %cst_2 : f32
linalg.yield %10 : f32
} -> tensor<256x256xf32>
// WEIGHT + BIAS 2
%cst_3 = arith.constant dense<1.000000e+00> : tensor<256x512xf32>
%cst_4 = arith.constant dense<1.000000e+00> : tensor<256x512xf32>
// BUFFER 2
%cst_5 = arith.constant 0.000000e+00 : f32
%5 = tensor.empty() : tensor<256x512xf32>
%6 = linalg.fill ins(%cst_5 : f32) outs(%5 : tensor<256x512xf32>) -> tensor<256x512xf32>
// GEMM 2
%7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%4, %cst_3 : tensor<256x256xf32>, tensor<256x512xf32>) outs(%6 : tensor<256x512xf32>) {
^bb0(%in: f32, %in_7: f32, %out: f32):
%10 = arith.mulf %in, %in_7 : f32
%11 = arith.addf %out, %10 : f32
linalg.yield %11 : f32
} -> tensor<256x512xf32>
// BIAS ADD 2
%8 = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%cst_4 : tensor<256x512xf32>) outs(%7 : tensor<256x512xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<256x512xf32>
// RELU 2
%cst_6 = arith.constant 0.000000e+00 : f32
%9 = linalg.generic {indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} outs(%8 : tensor<256x512xf32>) {
^bb0(%out: f32):
%10 = arith.maximumf %out, %cst_6 : f32
linalg.yield %10 : f32
} -> tensor<256x512xf32>
// RETURN
return %9 : tensor<256x512xf32>
}
}
To look at the compiler passes and their output, just run:
$ ./scripts/debug/debug_all_passes.sh
This will generate the IR with mlir-gen, pipe it to tpp-opt through the default TPP passes, dump the IR after each pass, split the result into multiple files, and compare them whenever the IR changes.
The default diff program is diff, but you can choose another (e.g. vimdiff or meld).
You can also pass the BIN directory where the tools are. By default, the tool assumes it's in the git root build/bin.
$ ./scripts/debug/debug_all_passes.sh -d meld -b /tmp/tpp-mlir-build
The script will dump the output and split the files in a temporary directory created with mktemp, and will print its path to stdout to help you look at the files directly.
The models use a single input and return a single output; all weights and biases are constants (inference). Each GEMM creates a new buffer (C matrix) and fills it with zeroes. The MLP model reuses the same output buffer to do bias add and ReLU in-place. The buffer allocated by the final GEMM is returned as the output.
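In plain NumPy terms, the generated kernel computes the following (an illustrative sketch, not part of the compiler):
import numpy as np

# NumPy equivalent of the kernel above; the weights and biases are splat 1.0,
# matching the dense<1.0> constants in the IR.
def entry(x):                                    # x: [256, 128]
    w1 = np.ones((128, 256), dtype=np.float32)
    b1 = np.ones((256, 256), dtype=np.float32)
    w2 = np.ones((256, 512), dtype=np.float32)
    b2 = np.ones((256, 512), dtype=np.float32)
    y = np.maximum(x @ w1 + b1, 0.0)             # GEMM 1 + BIAS ADD 1 + RELU 1
    return np.maximum(y @ w2 + b2, 0.0)          # GEMM 2 + BIAS ADD 2 + RELU 2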
As expected, pack-matmul generates 3 packs (input, weight, buffer) and one unpack (GEMM result) for each layer.
Note that the packing transposes the blocks, so both matrices (A and B) are traversed horizontally for maximum pre-fetching.
%2 = tensor.empty() : tensor<8x4x32x32xf32>
%pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>
%3 = tensor.empty() : tensor<8x4x32x32xf32>
%pack_3 = tensor.pack %cst_0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<128x256xf32> -> tensor<8x4x32x32xf32>
%4 = tensor.empty() : tensor<8x8x32x32xf32>
%pack_4 = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %4 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
%5 = linalg.generic ... ins(%pack, %pack_3 : tensor<8x4x32x32xf32>, tensor<8x4x32x32xf32>) outs(%pack_4 : tensor<8x8x32x32xf32>)
%unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
// Add, Relu on original shapes 256x256 ...
// Original shape tensor.empty for the next layer
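To make the block layout concrete, here is a NumPy sketch of what these tensor.pack ops compute (an illustration of the layout, not the pass itself):
import numpy as np

A = np.arange(256 * 128, dtype=np.float32).reshape(256, 128)
# Input pack: [256, 128] -> [8, 4, 32, 32], where
# packed_A[i, j, ii, jj] == A[i*32 + ii, j*32 + jj]
packed_A = A.reshape(8, 32, 4, 32).transpose(0, 2, 1, 3)

B = np.ones((128, 256), dtype=np.float32)
# Weight pack with outer_dims_perm = [1, 0]: the blocks are transposed, so
# packed_B[j, k, kk, jj] == B[k*32 + kk, j*32 + jj], and both packed operands
# are traversed along their contiguous innermost blocks during the GEMM.
packed_B = B.reshape(4, 32, 8, 32).transpose(2, 0, 1, 3)

assert packed_A.shape == (8, 4, 32, 32)
assert packed_B.shape == (8, 4, 32, 32)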
This pass propagates the packs/unpacks by replicating on the element-wise ops the same packing that was done on the GEMMs above.
The result is a proliferation of tensor.pack and tensor.unpack operations across every linalg op.
However, because each unpack is followed by an identical pack, it gets simplified later.
// Input, weight and accumulation packing
%pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>
%pack_3 = tensor.pack %cst_0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<128x256xf32> -> tensor<8x4x32x32xf32>
%pack_4 = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %4 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
// GEMM
%5 = linalg.generic { ... } -> tensor<8x8x32x32xf32>
// Unpack
%unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
// GEMM's output re-pack & bias pack
%pack_5 = tensor.pack %unpack inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %6 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
%pack_6 = tensor.pack %cst_1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %7 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
// BIAS ADD
%8 = linalg.generic { ... } -> tensor<8x8x32x32xf32>
// Unpack
%unpack_7 = tensor.unpack %8 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
// ADD's output re-pack
%pack_8 = tensor.pack %unpack_7 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %9 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
// RELU
%10 = linalg.generic { ... } -> tensor<8x8x32x32xf32>
// NOTE: RELU's output is not unpacked, and the input pack to the next layer is removed
// This works fine and may be a side-effect of the packing (not adding the last one)
// But it can also be a BUG that just happens to work and isn't producing correct code
As expected, because the constants are splat, the "constant folding" is just a "constant reshape", so we get:
// Original
%cst = arith.constant dense<1.000000e+00> : tensor<256x512xf32>
...
%3 = tensor.empty() : tensor<8x4x32x32xf32>
%pack_3 = tensor.pack %cst_0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<128x256xf32> -> tensor<8x4x32x32xf32>
// Reshaped
%cst = arith.constant dense<1.000000e+00> : tensor<16x8x32x32xf32>
...
// NOTE: The old packed buffer continues to exist and could be removed by the pass
// But following cleanups will remove it anyway
%3 = tensor.empty() : tensor<8x4x32x32xf32>
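Since every element of a splat constant is identical, folding the pack only needs to re-create the splat in the packed shape; a NumPy illustration:
import numpy as np

# Packing a splat constant does not need to move any data: the packed tensor
# is just a splat of the new shape.
cst = np.full((128, 256), 1.0, dtype=np.float32)
packed_cst = np.full((8, 4, 32, 32), 1.0, dtype=np.float32)
assert cst.size == packed_cst.size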
The buffer allocation is also "constant folded": even though it isn't a constant per se, it is initialized with zeroes, so it can be "constant reshaped":
// From
%0 = tensor.empty() : tensor<256x256xf32>
%1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
...
%4 = tensor.empty() : tensor<8x8x32x32xf32>
%pack_4 = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %4 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
// To
// NOTE: The original unpacked tensor continues to exist, as the output of the GEMM is still unpacked into this buffer
// This will be simplified on the next pass (since we'll remove the `unpack`/`pack` pair)
%0 = tensor.empty() : tensor<256x256xf32>
%1 = linalg.fill ins(%cst_6 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
...
%4 = tensor.empty() : tensor<8x8x32x32xf32>
%5 = linalg.fill ins(%cst_6 : f32) outs(%4 : tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32>
In the MLP model, it propagates the packed shape through the following element-wise ops.
On the GEMM-only kernel, it fuses the unpack with the next layer's pack directly.
// GEMM on packed shapes,
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst_1 : tensor<8x4x32x32xf32>, tensor<8x4x32x32xf32>) outs(%2 : tensor<8x8x32x32xf32>) {
^bb0(%in: f32, %in_4: f32, %out: f32):
%13 = arith.mulf %in, %in_4 : f32
%14 = arith.addf %out, %13 : f32
linalg.yield %14 : f32
} -> tensor<8x8x32x32xf32>
// No more tensor.unpack
// Bias Add on the packed shape
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_2 : tensor<8x8x32x32xf32>) outs(%3 : tensor<8x8x32x32xf32>) {
^bb0(%in: f32, %out: f32):
%13 = arith.addf %in, %out : f32
linalg.yield %13 : f32
} -> tensor<8x8x32x32xf32>
// ReLU on the packed shape
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%4 : tensor<8x8x32x32xf32>) {
^bb0(%out: f32):
%13 = arith.maximumf %out, %cst_3 : f32
linalg.yield %13 : f32
} -> tensor<8x8x32x32xf32>
// Next GEMM takes the packed shape %5 directly
%10 = linalg.generic ... ins(%5, %cst : tensor<8x8x32x32xf32>, tensor<16x8x32x32xf32>) outs(%9 : tensor<8x16x32x32xf32>)
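The six iterators of the packed GEMM correspond to block-level loops (a NumPy sketch of the semantics encoded by the affine maps above, not the generated code):
import numpy as np

def packed_gemm(packed_a, packed_b, packed_c):
    # packed_a: [8, 4, 32, 32], packed_b: [8, 4, 32, 32], packed_c: [8, 8, 32, 32]
    I, K = packed_a.shape[0], packed_a.shape[1]    # d0, d2
    J = packed_b.shape[0]                          # d1
    for i in range(I):                             # block row of C
        for j in range(J):                         # block column of C
            for k in range(K):                     # block reduction
                # 32x32 block GEMM: C[i, j] += A[i, k] @ B[j, k]
                # (d3, d5) x (d5, d4) -> (d3, d4)
                packed_c[i, j] += packed_a[i, k] @ packed_b[j, k]
    return packed_c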
At the end of this pass, at least for MLP-like models, the kernel code should only contain one pack (for the input, at the beginning) and one unpack (for the output, at the end, before the return).
All of the previous unpacked constant tensors and empty buffers should have been removed too.
This pass creates scf.forall loops and adds the tiled linalg ops inside. Since we're fusing GEMMs with element-wise ops in the MLP, the "tile" operation is a BRGEMM, not a GEMM, so the final op becomes linalg.batch_reduce_matmul instead of just linalg.matmul.
The final pattern is:
// NOTE: Because we're using splat constant, the tiling just reduces the dimensionality of the tensor
// On real benchmarks we use random tensors, so this becomes an `extract_slice` like the input
%cst = arith.constant dense<1.000000e+00> : tensor<4x32x32xf32>
// C matrix, note the `zero` was inlined in the loop
// BUG: The second layer is not "moved" inside, there's still a zero outside and one inside
%1 = tensor.empty() : tensor<8x8x32x32xf32>
// Parallel loop
%2 = scf.forall (%arg1, %arg2) in (8, 8) shared_outs(%arg3 = %1) -> (tensor<8x8x32x32xf32>) {
// C matrix tile + zero
%extracted_slice = tensor.extract_slice %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x8x32x32xf32> to tensor<32x32xf32>
%7 = linalg.fill ins(%cst_1 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
// Input tile
%extracted_slice_2 = tensor.extract_slice %pack[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : tensor<8x4x32x32xf32> to tensor<4x32x32xf32>
// BRGEMM on block line x column
%8 = linalg.batch_reduce_matmul ins(%extracted_slice_2, %cst : tensor<4x32x32xf32>, tensor<4x32x32xf32>) outs(%7 : tensor<32x32xf32>) -> tensor<32x32xf32>
// Bias add and ReLU on 2D tile
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<32x32xf32>) outs(%8 : tensor<32x32xf32>) {
^bb0(%in: f32, %out: f32):
%11 = arith.addf %in, %out : f32
linalg.yield %11 : f32
} -> tensor<32x32xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%9 : tensor<32x32xf32>) {
^bb0(%out: f32):
%11 = arith.maximumf %out, %cst_2 : f32
linalg.yield %11 : f32
} -> tensor<32x32xf32>
// Write back into C matrix (async)
scf.forall.in_parallel {
tensor.parallel_insert_slice %10 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x8x32x32xf32>
}
}
The zero was moved inside the loop so that we can match zero + gemm into gemm_with_beta_zero.
The second layer is handled in the same way.
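For reference, linalg.batch_reduce_matmul accumulates a batch of small GEMMs into a single tile, which is why the zero fill can later be folded into it as BETA=0 (a NumPy sketch of the semantics):
import numpy as np

def batch_reduce_matmul(a, b, c):
    # a: [batch, M, K], b: [batch, K, N], c: [M, N] (accumulated in place)
    for r in range(a.shape[0]):
        c += a[r] @ b[r]
    return c

# With the fill moved inside the loop, `fill(0) + batch_reduce_matmul` on a
# 32x32 tile is equivalent to a BRGEMM with beta = 0.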
This second pass "consumes" the last unpack into the scf.forall, so that we write directly to the right tile instead of needing to go through the entire tensor again.
The first layer has no unpack and its consumer (the second layer) is also packed, so nothing happens to it.
// Input packing & packed buffer allocation
%pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<256x128xf32> -> tensor<8x4x32x32xf32>
%1 = tensor.empty() : tensor<8x8x32x32xf32>
// Layer 1 (note %arg3 is still packed)
%2 = scf.forall (%arg1, %arg2) in (8, 8) shared_outs(%arg3 = %1) -> (tensor<8x8x32x32xf32>) {
// BRGEMM + ADD + RELU in a packed shape
// Parallel write into packed shape
scf.forall.in_parallel {
tensor.parallel_insert_slice %9 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x8x32x32xf32>
}
}
// Final return (unpacked) BUG: Note the unnecessary whole-tensor zero
%3 = tensor.empty() : tensor<256x512xf32>
%4 = linalg.fill ins(%cst_2 : f32) outs(%3 : tensor<256x512xf32>) -> tensor<256x512xf32>
// Layer 2 (note %arg3 is now unpacked)
%5 = scf.forall (%arg1, %arg2) in (8, 16) shared_outs(%arg3 = %4) -> (tensor<256x512xf32>) {
// BRGEMM + ADD + RELU in a packed shape
// Parallel write into unpacked shape (note the equivalent dimensionality)
scf.forall.in_parallel {
tensor.parallel_insert_slice %11 into %arg3[%6, %7] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<256x512xf32>
}
}
return %5 : tensor<256x512xf32>
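A rough NumPy picture of the layer-2 change (illustrative only; compute_tile is a hypothetical stand-in for the BRGEMM + ADD + RELU body of the loop):
import numpy as np

def layer2(compute_tile):
    # Each 32x32 tile (returned by compute_tile) is written straight into the
    # unpacked [256, 512] result, instead of into a packed [8, 16, 32, 32]
    # buffer that would need to be unpacked afterwards.
    out = np.empty((256, 512), dtype=np.float32)
    for i in range(8):           # %arg1
        for j in range(16):      # %arg2
            out[i*32:(i+1)*32, j*32:(j+1)*32] = compute_tile(i, j)
    return out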
This basically lowers tensor.pack and tensor.unpack into loops.
We detect block transposes and lower them to tensor.extract_slice + tensor.parallel_insert_slice to speed up block copies.
// New buffer needs to be allocated
%0 = tensor.empty() : tensor<8x4x32x32xf32>
// Parallel loop
%1 = scf.forall (%arg1, %arg2) in (8, 4) shared_outs(%arg3 = %0) -> (tensor<8x4x32x32xf32>) {
// Block of 32x32
%7 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg1)
%8 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg2)
// Extract the 32x32 slice from %arg0[%arg1, %arg2]
%extracted_slice = tensor.extract_slice %arg0[%7, %8] [32, 32] [1, 1] : tensor<256x128xf32> to tensor<32x32xf32>
// Insert 32x32 slice into %arg3. NOTE: Shouldn't this be %arg3[%arg2, %arg1]?
scf.forall.in_parallel {
tensor.parallel_insert_slice %extracted_slice into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x4x32x32xf32>
}
}
This pattern will lower to memref.copy, which is more efficient than single-element copies.
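In NumPy terms, the lowered pack is just a loop of 32x32 block copies, each of which becomes a memref.copy after bufferization (a sketch, not the pass's output):
import numpy as np

def pack_input(a):
    # a: [256, 128] -> packed: [8, 4, 32, 32], one block copy per iteration
    packed = np.empty((8, 4, 32, 32), dtype=a.dtype)
    for i in range(8):
        for j in range(4):
            packed[i, j] = a[i*32:(i+1)*32, j*32:(j+1)*32]
    return packed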
This pass (and the subsequent bufferization cleanups) converts tensors to memrefs and makes sure all required buffers are allocated. The objective here is to allocate as few buffers as possible and to reuse them as much as possible.
Currently, bufferization only reuses a buffer within the same chain (in-place computation), but it would also be beneficial to reuse one when its previous use is dead and the new shape is the same as the original buffer's.
Note that this only adds memref.alloc, not memref.dealloc, which will be added by later passes.
Constant tensors are bufferized as a pair of memref.global and memref.get_global:
memref.global "private" constant @__constant_32x32xf32 : memref<32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_8x32x32xf32 : memref<8x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_4x32x32xf32 : memref<4x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
%0 = memref.get_global @__constant_4x32x32xf32 : memref<4x32x32xf32>
%1 = memref.get_global @__constant_8x32x32xf32 : memref<8x32x32xf32>
%2 = memref.get_global @__constant_32x32xf32 : memref<32x32xf32>
Buffers are only allocated when they can't be reused:
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
scf.forall (%arg1, %arg2) in (8, 4) {
%3 = affine.apply #map(%arg1)
%4 = affine.apply #map(%arg2)
...
// No more `tensor.parallel_insert_slice`, writes are done directly to the buffers
}
scf.forall (%arg1, %arg2) in (8, 4) {
...
// This is the `tensor.pack` converted into tile copies above, now very clear
%subview = memref.subview %arg0[%3, %4] [32, 32] [1, 1] : memref<256x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
%subview_2 = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
memref.copy %subview, %subview_2 : memref<32x32xf32, strided<[128, 1], offset: ?>> to memref<32x32xf32, strided<[32, 1], offset: ?>>
}
Now, the outs really means "write to that buffer":
scf.forall (%arg1, %arg2) in (8, 8) {
...
linalg.batch_reduce_matmul ins(%subview_2, %0 : memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%2 : memref<32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
^bb0(%in: f32, %out: f32):
%3 = arith.addf %in, %out : f32
linalg.yield %3 : f32
}
linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
^bb0(%out: f32):
%3 = arith.maximumf %out, %cst : f32
linalg.yield %3 : f32
}
...
}
However, there's a duplication at the end of the loop that will be cleaned up by later passes:
// New buffer (C)
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
// Layer loop
scf.forall (%arg1, %arg2) in (8, 8) {
// C tile
%subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
// C = <0>
linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
// Input
%subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// Compute
linalg.batch_reduce_matmul
linalg.generic { arith.addf }
linalg.generic { arith.maximumf }
// NOTE: Redundant copy, bufferization is not smart enough, needs cleanup later
%subview_3 = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
memref.copy %subview, %subview_3 : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<32x32xf32, strided<[32, 1], offset: ?>>
}
This is the pass that creates memref.dealloc based on ownership.
In our example (2 layers), there are three main allocations (%alloc, %alloc_0, and %alloc_1).
The first (%alloc) is for the tensor.pack; it is used by the first layer and is dead by the end of that layer.
The second (%alloc_0) is used as input to the second layer, but by the end of the second layer it's dead, so it has to be deallocated.
The third (%alloc_1) is returned by the function, so it's still alive after exit and cannot be deallocated.
func.func @entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
...
// Layer 1
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
scf.forall (%arg1, %arg2) in (8, 8) {
...
}
// Layer 2
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xf32>
// Note: redundant fill is still here
linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<256x512xf32>)
scf.forall (%arg1, %arg2) in (8, 16) {
...
}
// Return, retaining %alloc_1
%3 = bufferization.dealloc (%alloc, %alloc_0, %alloc_1 : memref<8x4x32x32xf32>, memref<8x8x32x32xf32>, memref<256x512xf32>) if (%true, %true, %true) retain (%alloc_1 : memref<256x512xf32>)
return %alloc_1 : memref<256x512xf32>
}
Converts bufferization.dealloc into the appropriate memref.dealloc:
...
memref.dealloc %alloc : memref<8x4x32x32xf32>
memref.dealloc %alloc_0 : memref<8x8x32x32xf32>
return %alloc_1 : memref<256x512xf32>
}
}
NOTE: The first allocation could have been freed by the end of the first layer, but since it goes through bufferization.dealloc, all deallocations happen at the end of the model.
This is fine from a memory-safety point of view, but larger models could unnecessarily run out of memory.
This pass converts memref operations into XSMM calls.
For now, this is only done for the packing, where we use memref.copy, which gets converted to an XSMM identity (copy).
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
scf.forall (%arg1, %arg2) in (8, 4) {
%3 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg1)
%4 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg2)
%subview = memref.subview %arg0[%3, %4] [32, 32] [1, 1] : memref<256x128xf32> to memref<32x32xf32, strided<[128, 1], offset: ?>>
%subview_2 = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
// NOTE: The dispatch is the function that compiles the kernel and only needs to be run once.
// It will be hoisted and commoned up on a later pass
%5 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32
// This is identical to `memref.copy`
xsmm.unary identity(data_type = f32, %5, %subview, %subview_2) : (i64, memref<32x32xf32, strided<[128, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
}
This is the core transform that converts generic Linalg operations into XSMM calls. The previous pass only converts the packs to XSMM, this one converts everything else.
// Buffer 1
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
// Layer 1
scf.forall (%arg1, %arg2) in (8, 8) {
// Tiles
%subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
// Fill with zero -> XSMM ZERO
// NOTE: This can be converted to BRGEMM(BETA=0) later
%4 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
xsmm.unary zero(data_type = f32, %4, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
%subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// BRGEMM
// NOTE: This can be fused into FUSED_BRGEMM(BIN=ADD, UN=RELU)
%5 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
xsmm.brgemm(data_type = f32, %5, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
// BIAS ADD
%6 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
xsmm.binary add(data_type = f32, %6, %2, %subview, %subview) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
// RELU
%7 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
xsmm.unary relu(data_type = f32, %7, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
}
NOTE: All dispatches are still inside the inner loop; they will be hoisted and commoned up later.
For our use here, scf.forall and scf.parallel mean the same thing, but the OpenMP conversion only detects the latter, so we need to convert to scf.parallel to get OpenMP functionality.
// From
scf.forall (%arg1, %arg2) in (8, 8) {
...
}
// To
scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
...
scf.yield
}
This is the pass that hoists all loop-invariant calls (our dispatches) out of the loops:
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
// All dispatches outside of the loop now
%4 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
%5 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
%6 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
%7 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
// Layer 1 loop
scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
// C tile
%subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
// C = <0>
xsmm.unary zero(data_type = f32, %4, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
// Input tile
%subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// BRGEMM: only execution here
xsmm.brgemm(data_type = f32, %5, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
xsmm.binary add(data_type = f32, %6, %2, %subview, %subview) : (i64, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
xsmm.unary relu(data_type = f32, %7, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
scf.yield
}
This (future #752) pass looks at a sequence of XSMM operations and tries to fuse them according to the target's rules.
LIBXSMM supports fused BRGEMMs, where a single function call handles the GEMM-like operation plus a (configurable) binary and unary operation, and has flags such as BETA=0 that allow us to elide the initial xsmm.unary zero.
NOTE: The initial zero inside the loop isn't costly (it operates on the tile, so it's cache-friendly), and the same can be said for the following element-wise operations.
However, this is only the case for simple patterns.
For example, if the element-wise ops don't bufferize correctly and add allocs in between, fusing removes that complexity and returns performance to base levels.
TODO: Once the PR is merged, add IR for this pass.
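Conceptually, the fused call would compute something like the following for each 32x32 tile (a NumPy sketch of the intended semantics, not LIBXSMM's API):
import numpy as np

def fused_brgemm_add_relu(a, b, bias):
    # a: [batch, M, K], b: [batch, K, N], bias: [M, N]
    c = np.zeros((a.shape[1], b.shape[2]), dtype=a.dtype)  # BETA=0: no separate zero kernel
    for r in range(a.shape[0]):
        c += a[r] @ b[r]                                   # batch-reduce GEMM
    c += bias                                              # fused binary: ADD
    return np.maximum(c, 0.0)                              # fused unary: RELU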
Once all XSMM ops are lowered and fused, we convert them to function calls. These functions are available in our runtime library that marshals arguments and return values and calls LIBXSMM functions directly.
// From
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
%3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32
scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c4) step (%c1, %c1) {
...
xsmm.unary identity(data_type = f32, %3, %subview, %subview_2) : (i64, memref<32x32xf32, strided<[128, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
scf.yield
}
// To
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
%3 = call @xsmm_unary_dispatch(%c1_i64, %c1_i64, %c32_i64, %c32_i64, %c128_i64, %c32_i64, %c0_i64) : (i64, i64, i64, i64, i64, i64, i64) -> i64
scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c4) step (%c1, %c1) {
...
// NOTE: This marshalling is generated by the compiler, to extract MLIR memref semantics
// Setting up LIBXSMM's arguments and its flags is done in the runtime shim layer
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : memref<32x32xf32, strided<[128, 1], offset: ?>> -> memref<f32>, index, index, index, index, index
%intptr_3 = memref.extract_aligned_pointer_as_index %subview : memref<32x32xf32, strided<[128, 1], offset: ?>> -> index
%17 = arith.index_cast %intptr_3 : index to i64
%18 = llvm.inttoptr %17 : i64 to !llvm.ptr<f32>
%base_buffer_4, %offset_5, %sizes_6:2, %strides_7:2 = memref.extract_strided_metadata %subview_2 : memref<32x32xf32, strided<[32, 1], offset: ?>> -> memref<f32>, index, index, index, index, index
%intptr_8 = memref.extract_aligned_pointer_as_index %subview_2 : memref<32x32xf32, strided<[32, 1], offset: ?>> -> index
%19 = arith.index_cast %intptr_8 : index to i64
%20 = llvm.inttoptr %19 : i64 to !llvm.ptr<f32>
// Invoke call
func.call @xsmm_unary_invoke(%c1_i64, %3, %18, %offset, %20, %offset_5) : (i64, i64, !llvm.ptr<f32>, index, !llvm.ptr<f32>, index) -> ()
scf.yield
}