Linalg Generic Nesting

Linalg generic has a single region, representing the inner-most loop of the computation, and therefore can only represent perfectly nested loops.

This inhibits representation of more complex patterns (ex. convolution) and fusion of common kernels (ex. fully connected).

Fully Connected Layer

A fully connected layer is represented by a matrix multiply (the fully connected part) followed by some bias addition and an activation function (ex. maxf(x,0)).

While the bias and the activation can be element-wise, the matrix multiply clearly isn't (reduction in one dimension). For this reason, a Linalg generic cannot represent such a computation.

When a matrix multiply finishes accumulating into a value (or tile), the following element-wise operations can already be performed in place, before the rest of the matrices are multiplied together. This leads to better locality, since there's no need to fetch the accumulation result once again. Thus, tiling the matrix multiplication's reduction loop and fusing the following element-wise operations into each tile increases performance.

Loop Representation

The naive loop representation would be:

# Init
C[M, N] = 0.0

# Matmul
for (i, k, j) in (M, K, N):
    C[i, j] += A[i, k] * B[k, j]

# Bias Add
for (i, j) in (M, N):
    C[i, j] += Bias[i, j]

# Activation
for (i, j) in (M, N):
    C[i, j] = maxf(C[i, j], 0.0)

If the matrix is large enough (C >> cache), none of the operations will read from cache, because previous C elements will have been evicted already.

Blocking increases locality:

# Tile size = [ m, k, n ]
# Blocked Matmul
for (ib, jb) in (M/m, N/n):
    # Tile-local init (zero & prefetch)
    C[ib, jb] = 0.0

    # Outer reduction
    for (kb) in (K/k):
        # Tile matmul, reduces on C[ib, jb] tile
        for (i, k, j) in (m, k, n):
            C[i, j] += A[i, k] * B[k, j]

# Blocked Bias Add
for (ib, jb) in (M/m, N/n):
    # Tile element-wise add
    for (i, j) in (m, n):
        C[i, j] += Bias[i, j]

# Blocked Activation
for (ib, jb) in (M/m, N/n):
    # Tile element-wise activation
    for (i, j) in (m, n):
        C[i, j] = maxf(C[i, j], 0.0)

Now the blocked matmul keeps the C[ib, jb] tile in cache to both read from and write to during accumulation, while the A[ib, kb] and B[kb, jb] tiles are read from a higher cache level or memory each time.

The blocking has no effect on locality for the element-wise operations, since there's no accumulation. However, it allows us to detect a fusion pattern. Note that the outer loops are the same on all three operations, and that the tile being written to (C[ib, jb]) is the same in all three loops.

With this transformation, we can fuse all three operations in a single loop, allowing the two remaining element-wise operations to read from and write to the same cached tile from the matrix multiplication operation:

# Tile size = [ m, k, n ]
# Blocked Matmul
for (ib, jb) in (M/m, N/n):
    # Tile-local init (zero & prefetch)
    C[ib, jb] = 0.0

    # Outer (tile) reduction
    for (kb) in (K/k):
        # Tile matmul, reduces on C[ib, jb] tile
        for (i, k, j) in (m, k, n):
            C[i, j] += A[i, k] * B[k, j]

    # Tile element-wise add + activation on C[ib, jb] tile
    for (i, j) in (m, n):
        C[i, j] += Bias[i, j]
        C[i, j] = maxf(C[i, j], 0.0)

Note that each [ib, jb] loop only reads from and writes to the same C[ib, jb] tile, and only reads from the other tensors, so there is no write-dependency between the outer loops.

Therefore, the outer loops are completely parallel and can execute even on different threads (OpenMP, GPU), while the zero-init acts as a prefetch for the entire accumulation tile that, if chosen wisely, never leaves the cache.

This helps reduce the memory cost of the matmul and lets all element-wise operations be done in registers/cache. But now we have a loop nest structure that isn't perfectly nested, ie. there are multiple loops inside loops at different iteration points.

Linalg Generic Representation

Linalg generic can easily represent the non-blocked versions of these operations:

// Matmul (d0, d1, d2) == (M, N, K)
#a-map = affine_map<(d0, d1, d2) -> (d0, d2)>
#b-map = affine_map<(d0, d1, d2) -> (d2, d1)>
#c-map = affine_map<(d0, d1, d2) -> (d0, d1)>
%1 = linalg.generic
        {indexing_maps = [#a-map, #b-map, #c-map],
         iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<MxKxf32>,
                 tensor<KxNxf32>)
    outs(%C : tensor<MxNxf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32) :
        %d = arith.mulf %a, %b: f32
        %e = arith.addf %c, %d: f32
        linalg.yield %e : f32
} -> tensor<MxNxf32>

// Element-wise
#ew-map = affine_map<(d0, d1) -> (d0, d1)>
%2 = linalg.generic
        {indexing_maps = [#ew-map, #ew-map, #ew-map],
         iterator_types = ["parallel", "parallel"]}
    ins(%1, %arg2 : tensor<MxNxf32>,
                    tensor<MxNxf32>)
    outs(%arg3 : tensor<MxNxf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
        %add = arith.addf %in, %in_0 : f32
        linalg.yield %add : f32
} -> tensor<MxNxf32>

%cst = arith.constant 0.000000e+00 : f32
%3 = linalg.generic
        {indexing_maps = [#ew-map, #ew-map],
         iterator_types = ["parallel", "parallel"]}
    ins(%2 : tensor<MxNxf32>)
    outs(%arg3 : tensor<MxNxf32>) {
    ^bb0(%in: f32, %out: f32):
        %max = arith.maxf %in, %cst : f32
        linalg.yield %max : f32
} -> tensor<MxNxf32>

But because the blocked version contains non-trivial nesting, the loops are no longer perfectly nested, and we cannot represent that kind of fusion in a single generic operation.

The main problems are:

  1. There is only one region, which encodes the body of the inner-most loop. If we want more, we need more regions and a way to combine them logically with the rest of the operation's parameters (maps, iterators).
  2. The affine maps and iterators apply to that single region. If we want multiple regions, we need a way to compose multiple maps and reason about their compatibility at different nest levels.

So, to solve this problem, we need the following steps:

  1. Design a representation for multiple regions and how they nest together. This includes region syntax and the breaking up of maps and iterators.
  2. Design affine map semantics that allow composition at different insertion points (nest levels). Given that loops are hierarchical (tree-like), we only need to compose by nesting (one on top of the other), not necessarily by intersection (two nests at the same level).
  3. Create a set of affine map transformations that allow us to merge regions from different generic operations into each other at different nest levels.

Multiple Region Syntax

A packed Fully Connected layer as a Linalg generic looks like:

// Blocked affine maps:
// (d0, d1, d2), (d3, d4, d5) == (MB, NB, KB), (m, n, k)
// Packed matrices, so C = A * BT (with B columns contiguous in memory)

// Blocked A map: (MB, KB, m, k)
#a-map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// Blocked Transposed B map: (NB, KB, k, n)
#b-map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
// Blocked C map: (MB, NB, m, n)
#c-map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// Blocked matmul
%0 = linalg.generic
    {indexing_maps = [#a-map, #b-map, #c-map],
        iterator_types = [
            // MB, NB, KB
            "parallel", "parallel", "reduction",
            // m, n, k
            "parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<MBxKBx(m)x(k)xf32>, tensor<NBxKBx(k)x(n)xf32>)
    outs(%C : tensor<MBxNBx(m)x(n)xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
        %4 = arith.mulf %in, %in_2 : f32
        %5 = arith.addf %out, %4 : f32
        linalg.yield %5 : f32
} -> tensor<MBxNBx(m)x(n)xf32>

// Element-wise map is identity
#ew-map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Bias Add
%1 = linalg.generic
    {indexing_maps = [#ew-map, #ew-map, #ew-map],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%0, %BIAS : tensor<MBxNBx(m)x(n)xf32>, tensor<MBxNBx(m)x(n)xf32>) outs(%pack_2 : tensor<MBxNBx(m)x(n)xf32>) {
    ^bb0(%in: f32, %in_9: f32, %out: f32):
        %11 = arith.addf %in, %in_9 : f32
        linalg.yield %11 : f32
    } -> tensor<MBxNBx(m)x(n)xf32>

// Activation
%cst = arith.constant 0.000000e+00 : f32
%2 = linalg.generic
    {indexing_maps = [#ew-map, #ew-map],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
     ins(%1 : tensor<MBxNBx(m)x(n)xf32>) outs(%pack_2 : tensor<MBxNBx(m)x(n)xf32>) {
    ^bb0(%in: f32, %out: f32):
        %7 = arith.maxf %in, %cst : f32
        linalg.yield %7 : f32
} -> tensor<MBxNBx(m)x(n)xf32>

return %2 : tensor<MBxNBx(m)x(n)xf32>

Note: Packing a matrix for a faster matmul means blocking it into tiles and transposing those tiles so that the B matrix is read contiguously in memory. This is why the reduction pairs A's MBxKB blocks with B's NBxKB blocks, rather than B being laid out as KBxNB.
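
As an illustration of that layout, here is roughly how B could be packed with tensor.pack (a sketch assuming concrete sizes K = N = 256 and tile sizes k = n = 32; the %dest and %B_packed names are made up for this example):

// Pack B[K, N] into B_packed[NB, KB, k, n]: N is blocked into the outermost
// position (outer_dims_perm) and each k x n tile becomes contiguous in memory.
%dest = tensor.empty() : tensor<8x8x32x32xf32>
%B_packed = tensor.pack %B
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [32, 32]
    into %dest : tensor<256x256xf32> -> tensor<8x8x32x32xf32>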

Element-wise fusion

Fusing the two element-wise operations is trivial for Linalg, as they use the same affine maps and iterator types, and the second generic fully consumes the results of the first, element-wise.

This means that, for every element of the first operation, the second can happen immediately afterwards without needing a second loop. This is the same fusion that we see in pseudo-code above.

// Element-wise map is identity
#ew-map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Bias Add + Activation
%cst = arith.constant 0.000000e+00 : f32
%7 = linalg.generic
    {indexing_maps = [#ew-map, #ew-map, #ew-map],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%pack_3, %pack_4 : tensor<MBxNBx(m)x(n)xf32>, tensor<MBxNBx(m)x(n)xf32>) outs(%pack_2 : tensor<MBxNBx(m)x(n)xf32>) {
    ^bb0(%in: f32, %in_9: f32, %out: f32):
        %11 = arith.addf %in, %in_9 : f32
        %12 = arith.maxf %11, %cst : f32
        linalg.yield %12 : f32
    } -> tensor<MBxNBx(m)x(n)xf32>

But now we cannot fuse this operation with the matmul because the affine maps and iterators are different. So we need a way to represent nesting with the same maps and iterators, and then inside the nested regions, have additional maps that compose with the outer maps to form the full operation.

Tensor of Tensors

If we try to create a partial affine map for a packed type such as tensor<MBxNBx(m)x(n)xf32>, we'll run into notation problems. For example, how do we separate the parallel outer loops (MBxNB) from the GEMM inner loop (mxn)?

One way would be to use floordiv m and floordiv n, but this would create a dependency between dimensions that can't be tracked once persisted (ie. we can only guess that floordiv 32 is related to the inner m).
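
For illustration, such a flattened map might look like the following (a sketch; the #flat-c-map name is made up and the tile sizes are hard-coded to 32):

// Map from the original (M, N) iteration space into packed (MB, NB, m, n) indices.
// Once the constant 32 is printed, nothing ties it back to the symbolic tile sizes.
#flat-c-map = affine_map<(d0, d1) -> (d0 floordiv 32, d1 floordiv 32, d0 mod 32, d1 mod 32)>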

An easier way would be to split the tensor into outer and inner parts, with the inner part as a tensor or vector. During tiling, the outer tensor would have an affine map over its outer dims, and the linalg ops inside the region would have affine maps over the inner dims, operating on scalar types (or smaller tensors).

For example:

// Original packed type
tensor<MB x KB x m x k x f32>
// Original map: (MB, KB, m, k)
#a-map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>

// Would be:
tensor<MB x KB x tensor<m x k x f32>>
// With outer map (over MB, NB, KB)
#a-outer-map = affine_map<(d0, d1, d2) -> (d0, d2)>
// And inner map (over m, n, k)
#a-inner-map = affine_map<(d0, d1, d2) -> (d0, d2)>
// Which has the same form, as we're tiling rows/cols in blocks
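
As a sanity check (the nesting notation below is illustrative, not an existing MLIR composition operation), stacking the inner map under the outer one should recover the original packed map:

// Outer map over (MB, NB, KB) selects (MB, KB); inner map over (m, n, k) selects (m, k).
// Nesting the two gives (MB, NB, KB, m, n, k) -> (MB, KB, m, k), i.e. the original
// #a-map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>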

Multi-level fusion

Following the fusion style from the pseudo-code above, a fused matmul and element-wise would have:

  • Two parallel loops (the outer blocks MBxNB)
  • One outer-reduction loop for the matmul, with a tile matmul inside
  • One tile element-wise for both bias add and activation

Ignoring what the maps look like for now, this would be the shape:

%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.generic
    {indexing_maps = [#MNB-map, #MNB-map],
     iterator_types = [ "parallel", "parallel" ]}
    ins(%A, %B : tensor<MB x KB x tensor<m x k x f32>>, tensor<NB x KB x tensor<k x n x f32>>)
    outs(%C : tensor<MB x NB x tensor<m x n x f32>>) {
    // NOTE: Arguments are no longer scalar
    ^bb0(%a: tensor<m x k x f32>, %b: tensor<k x n x f32>, %c: tensor<m x n x f32>):

        // NOTE: Region operations are no longer scalar
        %mm = linalg.generic
                {indexing_maps = [ #KB-map,
                                   #a-map, #b-map, #c-map ],
                iterator_types = [ "reduction",
                                   "parallel", "parallel", "reduction" ]}
                ins(%a, %b : tensor<m x k x f32>, tensor<k x n x f32>)
                outs(%c : tensor<m x n x f32>) {
                ^bb0(%in: f32, %in_2: f32, %out: f32):
                    %4 = arith.mulf %in, %in_2 : f32
                    %5 = arith.addf %out, %4 : f32
                    linalg.yield %5 : f32
                // NOTE: Yielding the current tile shape
                } -> tensor<m x n x f32>

        // This is common in `scf.for` patterns
        %bias = tensor.extract_slice %BIAS[MB, NB, 0, 0] [1, 1, m, n] [1, 1, 1, 1]
            : tensor<MB x NB x m x n x f32> to tensor<m x n x f32>

        // NOTE: Another "tile" generic with _PARALLEL_ maps
        %act = linalg.generic
                {indexing_maps = [ #ew-map, #ew-map, #ew-map ],
                iterator_types = [ "parallel", "parallel" ]}
                ins(%mm, %bias : tensor<m x n x f32>, tensor<m x n x f32>)
                outs(%c : tensor<m x n x f32>) {
                ^bb0(%in: f32, %in_2: f32, %out: f32):
                    %11 = arith.addf %in, %in_2 : f32
                    %12 = arith.maxf %11, %cst : f32
                    linalg.yield %12 : f32
                // NOTE: Yielding the current tile shape
                } -> tensor<m x n x f32>
        linalg.yield %act : tensor<m x n x f32>

// NOTE: Yielding the current overall shape
} -> tensor<MB x NB x tensor<m x n x f32>>

In summary, the requirements over existing generic semantics would be:

  • Allow non-scalar block arguments and operations (including other generic ops)
  • A way to represent tile shapes using non-scalar affine-maps (tensor, memref, vector)
  • Add verification for maps/iterators and yield shapes
  • Support fusion by breaking affine maps / iterator sequences into common denominators and adding the remainders inside the region as another generic.

Pros:

  • Able to keep affine maps and iterator semantics for longer when fusing
  • Able to verify shape semantics, as they all must match with maps and iterators (unlike scf.for family)
  • Can mix and match scalar and tile operations under the same generic region (as long as shapes and maps match)

Cons:

  • Complexity in verification with non-scalar operations
  • Looks "too much like an scf.for structure", may not have enough "pros" to justify complexity
  • Will need a whole new lowering stage to scf operations

Element-wise Slicing

While element-wise operations can be directly fused if their maps are identical, there are two cases where you may want to slice their maps:

  1. When the element-wise operations aren't natively on the same elements (ex. different dimensions), but can still be fused at a tile level (ex. broadcast).
  2. When there are non-element-wise operations in the use-def chain (ex. matmul) and you want to fuse the element-wise tiles with the block reduction of the matmul (example above).

Starting with the simplest, let's look at a broadcast add plus an activation.

// Element-wise map is identity
#ew-map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#bc-map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>

%BIAS = ... : tensor<NBx(n)xf32>

// Bias Add
%1 = linalg.generic
    {indexing_maps = [#ew-map, #bc-map, #ew-map],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%0, %BIAS : tensor<MBxNBx(m)x(n)xf32>, tensor<NBx(n)xf32>) outs(%pack_2 : tensor<MBxNBx(m)x(n)xf32>) {
    ^bb0(%in: f32, %in_9: f32, %out: f32):
        %11 = arith.addf %in, %in_9 : f32
        linalg.yield %11 : f32
    } -> tensor<MBxNBx(m)x(n)xf32>

// Activation
%cst = arith.constant 0.000000e+00 : f32
%2 = linalg.generic
    {indexing_maps = [#ew-map, #ew-map],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
     ins(%1 : tensor<MBxNBx(m)x(n)xf32>) outs(%pack_2 : tensor<MBxNBx(m)x(n)xf32>) {
    ^bb0(%in: f32, %out: f32):
        %7 = arith.maxf %in, %cst : f32
        linalg.yield %7 : f32
} -> tensor<MBxNBx(m)x(n)xf32>

return %2 : tensor<MBxNBx(m)x(n)xf32>

The main complication here is that the %BIAS tensor was not 2D to begin with, but 1D. Thus, when blocked, it only had one dimension to tile. This dimension will be read over and over for every tile column in the source argument %0, so it's still a tile operation on the "same" iteration space, but broadcast to fit the larger shape.

The obvious "solution" would be to broadcast the smaller shape into the larger shape and realize that the affine map end up being the same. But that means creating a new tensor (alloc + copy), which is expensive.

But we can see that this is a broadcast by looking at the two maps:

  • (d0, d1, d2, d3) vs (0, d1, 0, d3)

The matching positions for d1 and d3, as well as the constants for d0 and d2, tell us that this is a broadcast, so we can create a blocked map and its broadcast counterpart for the element-wise blocked version:

%BIAS = ... : tensor<NBx(n)xf32>
%cst = arith.constant 0.000000e+00 : f32

// Original blocked map: (MB, NB)
#MNB-map = affine_map<(d0, d1, d2, d3) -> (d0 floordiv m, d1 floordiv n, d2, d3)>
#MNB-bc-map = affine_map<(d0, d1, d2, d3) -> (0, d1 floordiv n, 0, d3)>

%0 = linalg.generic
    // NOTE: One is the standard MNB while the other is the broadcast version
    {indexing_maps = [#MNB-map, #MNB-bc-map],
     iterator_types = [ "parallel", "parallel" ]}
    ins(%A, %B : tensor<MBxNBxtensor<m x k x f32>>, tensor<NBx(n)xf32>) <--- This is WRONG!
    outs(%C : tensor<MBxNBx(m)x(n)xf32>) {


// Bias Add
%1 = linalg.generic
    {indexing_maps = [#MNB-map, #MNB-bc-map, #ew-map],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%0, %BIAS : tensor<MBxNBx(m)x(n)xf32>, tensor<NBx(n)xf32>) outs(%pack_2 : tensor<MBxNBx(m)x(n)xf32>) {
    }