
Raising Linalg to TPP on Tensor

Renato Golin edited this page Jun 15, 2023 · 3 revisions

This document collects the patterns we look for when raising linalg generic operations to TPP operations on tensors.

Because TPP on tensors does not use destination-passing style (DPS) while linalg does, we need to make sure the patterns we find are convertible without loss of information.

Basic Structure

The basic linalg generic structure is:

%out = linalg.generic
            {
              indexing_maps = [ #affine-maps ... ],
              iterator_types = [ "parallel" || "reduction" ... ]
            }
            ins (%arg1, %arg2, ...) // Tensor types
            outs (%shape || %init)  // Tensor types
        {
          ^bb0(%block_args ...):    // Scalar types
            %0 = scalar_op1
            %1 = scalar_op2
            ...
            linalg.yield %res       // Scalar type (into %out's element)
        } -> // Tensor types

If a generic uses its outs arguments, they are treated as "initialisation"; otherwise, they are just an indication of "shape".

Initialisation example

  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<8x32x32x32xf32>, tensor<32x32x32x32xf32>) outs(%arg3 : tensor<8x32x32x32xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %mul = arith.mulf %in, %in_0 : f32
      %add = arith.addf %out, %mul : f32  // Note: %out is an element of %arg3
      linalg.yield %add : f32
  } -> tensor<8x32x32x32xf32>

This is equivalent to:

  • %0 = linalg.matmul ins(%arg0, %arg1) outs(%arg3)
  • D = A x B + C (if out-of-line)
  • C = A x B + C (if in-line)

The selection between in-line and out-of-line depends on the liveness of %arg3 after the operation. DPS at the tensor level hints certain patterns to the compiler, but it is not a requirement to reuse particular buffers, especially when the compiler can prove that the value occupying a buffer is still live.
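As a hypothetical bufferized sketch (the value names %A, %B, %C and %D are illustrative), the two outcomes at memref level would look like:

  // %arg3 dead after the op: reuse its buffer in place (C = A x B + C)
  linalg.matmul ins(%A, %B : memref<32x32xf32>, memref<32x32xf32>)
                outs(%C : memref<32x32xf32>)

  // %arg3 still live: copy into a fresh buffer first (D = A x B + C)
  %D = memref.alloc() : memref<32x32xf32>
  memref.copy %C, %D : memref<32x32xf32> to memref<32x32xf32>
  linalg.matmul ins(%A, %B : memref<32x32xf32>, memref<32x32xf32>)
                outs(%D : memref<32x32xf32>)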

Shape example

  %6 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4, %expanded2 : tensor<8x32x32x32xf32>, tensor<32x32xf32>) outs(%arg6 : tensor<8x32x32x32xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %add = arith.addf %in, %in_0 : f32 // Note: %out is not used here, so %arg6 is untouched
      linalg.yield %add : f32
  } -> tensor<8x32x32x32xf32>

This is equivalent to:

  • C = A + B (if out-of-line)
  • A = A + B (if in-line, reusing the first buffer)
  • B = A + B (if in-line, reusing the second buffer)

Because the IR is in SSA form, the source programming language is responsible for tracking its own variables and making sure they are used correctly. This remains a bufferization issue, not a semantics issue: the only reasoning behind reusing existing buffers is safety, correctness and performance, not program semantics.

DPS to non-DPS

When using destination-passing style, linalg creates empty tensors to define the shape or initialisation of the operations that consume them.

Trivial Case

When the empty tensor is just an artefact of DPS.

For example:

  // This is an empty tensor to hint buffer reuse. It does not guarantee allocation. Reading %0 is undefined behaviour (uninitialised).
  %0 = tensor.empty() : tensor<32x32xf32>
  // Now this puts zeros into it, so now we need an allocation
  %1 = linalg.fill ins(%const0 : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
  // And this uses the zero tensor as initialisation for the matmul
  %2 = linalg.matmul ins(%arg0, %arg1) outs(%1) -> tensor<32x32xf32>

It's easy to see that in non-DPS, this would simply be:

  // This must create a new buffer
  %0 = tpp.zero : tensor<32x32xf32>
  // This reads and writes to that buffer
  %1 = tpp.gemm (%arg0, %arg1, %0) : tensor<32x32xf32>

Buffer reuse

When two or more linalg ops create their own empty tensor but the buffer can be reused.

For example:

  // New tensor, for the zero: OK
  %0 = tensor.empty() : tensor<32x32xf32>
  %1 = linalg.fill ins(%const0 : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
  %2 = linalg.matmul ins(%arg0, %arg1) outs(%1) -> tensor<32x32xf32>

  // New tensor for the temporary result of the ADD below. Can be reused if %2 is dead.
  %3 = tensor.empty() : tensor<32x32xf32>
  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2, %2 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%3 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %add = arith.addf %in, %in_0 : f32
      linalg.yield %add : f32
  } -> tensor<32x32xf32>

The second empty above (%3) can be reused if the result of the matmul (%2) is not used after the add.

  // New tensor, for the zero: OK
  %0 = tensor.empty() : tensor<32x32xf32>
  %1 = linalg.fill ins(%const0 : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
  %2 = linalg.matmul ins(%arg0, %arg1) outs(%1) -> tensor<32x32xf32>

  // Reusing the same buffer as matmul, note that outs == %2
  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg2, %2 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%2 : tensor<32x32xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %add = arith.addf %in, %in_0 : f32
      linalg.yield %add : f32
  } -> tensor<32x32xf32>

In buffer semantics, this is the same as saying that %0, %1, %2 and %3 all "alias" each other, which is valid if-and-only-if the compiler can show that each of those values is dead by the time the buffer is overwritten.

In non-DPS, both cases should be converted to:

  %0 = tpp.zero : tensor<32x32xf32>
  %1 = tpp.gemm (%arg0, %arg1, %0) : tensor<32x32xf32>
  %2 = tpp.add (%arg2, %1) : tensor<32x32xf32>

And bufferization can take care of the reuse.

Unary Operations

There are two kinds of unary operations in TPP: compute and transform.

Compute Operations

These are unary operations that execute some arithmetic, mathematical or machine-learning function on the single input, returning a single output. They require the input and output types to be identical and can almost always run in place (depending on liveness, etc.).

Examples:

  %0 = tpp.square (%arg) : tensor<32x32xf32>
  %1 = tpp.exp(%0) : tensor<32x32xf32>
  %2 = tpp.relu(%1) : tensor<32x32xf32>

As above, regardless of whether an empty is defined, any generic that reads a single input, yields to the output without writing to outs, and has the right operation inside the region should match directly to a TPP op.

  %0 = tensor.empty() : tensor<32x32xf32>
  %1 = linalg.fill ins(%const0 : f32) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
  %2 = linalg.matmul ins(%arg0, %arg1) outs(%1) -> tensor<32x32xf32>

  // DPS out-of-place
  %3 = tensor.empty() : tensor<32x32xf32>
  %4 = linalg.generic { indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"] }
    ins(%2 : tensor<32x32xf32>)          // Note: %2 is the result of the matmul
    outs(%3 : tensor<32x32xf32>) {       // Note: %3 is the empty that is never used (ie. used as "shape")
      ^bb0(%a: f32, %b: f32):
        %max = arith.maxf %a, %c0 : f32  // Note: Reads from ins, does not write to outs (%a is an element of %2)
        linalg.yield %max : f32          // Note: %b is never read from/written to
  } -> tensor<32x32xf32>

  // DPS in-place
  %5 = linalg.generic { indexing_maps = [#map], iterator_types = ["parallel", "parallel"] }
    outs(%2 : tensor<32x32xf32>) {       // Note: %2 is the result of the matmul
      ^bb0(%a: f32):
        %max = arith.maxf %a, %c0 : f32  // Note: Reads from outs, does not write to it (%a is an element of %2)
        linalg.yield %max : f32
  } -> tensor<32x32xf32>

Both %4 and %5 above should match to just:

  %0 = tpp.zero : tensor<32x32xf32>
  %1 = tpp.gemm (%arg0, %arg1, %0) : tensor<32x32xf32>
  %2 = tpp.relu (%1) : tensor<32x32xf32>

Here, tpp.relu is treated as out-of-place because, by SSA rules, %2 is a new value. But bufferization (just like register allocation) can see whether %1 is unused after %2 is defined (liveness analysis) and reuse the same buffer at memref level.

Transform Operations

Any operation that changes the shape of the tensor is a transform operation. The three main flavours are: broadcast, reduce and transpose.

These operations can mean two things:

  1. An out-of-place operation from a smaller buffer / scalar into a larger buffer.
  2. An annotation to an operand of another TPP operation that supports "broadcastable/transposed reads" / "reducible/transposed writes".

In the first case, there is always a new buffer being created because they have different sizes. In the second case, there is no need to create a new buffer because the following operation will read from / write to smaller/different buffers.

The latter optimisation is only valid if the broadcasted value is only ever used by operations that can perform that read. If at least one cannot, we have to materialise the tensor.
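For instance (an illustrative sketch; "some.consumer" is a hypothetical operation that cannot perform broadcasted reads), a single such user forces materialisation even when other users could have fused the broadcast:

  // Broadcasted bias
  %bias = tpp.broadcast (%arg2 : tensor<32xf32>) : tensor<32x32xf32>
  // This user could fuse the broadcast as an annotated read
  %0 = tpp.add (%bias, %1) : tensor<32x32xf32>
  // Hypothetical user that needs the full 2D tensor: %bias must be materialised
  %2 = "some.consumer"(%bias) : (tensor<32x32xf32>) -> tensor<32x32xf32>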

Broadcast example:

  // Map "oneD" repeats the same row over all columns
  #oneD = affine_map<(d0, d1) -> (d1)>
  #twoD = affine_map<(d0, d1) -> (d0, d1)>

  %0 = tensor.empty() : tensor<32x32xf32>
  // Takes 1D tensor and "splats" into a 2D tensor
  %1 = linalg.generic {indexing_maps=[#oneD, #twoD],
                       iterator_types = ["parallel", "parallel"]}
    ins(%input_tensor : tensor<32xf32> ) outs(%0 : tensor<32x32xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32      // Note: does not explicitly write to outs, but the shape is different
  } -> tensor<32x32xf32>
  %2 = tensor.empty() : tensor<32x32xf32>
  // Same logic, named op
  %3 = linalg.broadcast ins(%input_tensor : tensor<32xf32>)
                        outs(%2 : tensor<32x32xf32>)
                        dimensions = [0]

They're both equivalent to:

  %0 = tpp.broadcast(%input_tensor : tensor<32xf32>) : tensor<32x32xf32>

As an annotation to a BRGEMM, the broadcast can be "fused" into the following operation by annotating the operand that is broadcasted:

  // GEMM buffer
  %0 = tpp.zero : tensor<32x32xf32>
  // GEMM
  %1 = tpp.brgemm (%arg0, %arg1, %0) : tensor<32x32xf32>
  // Broadcasted bias
  %bias = tpp.broadcast (%arg2 : tensor<32xf32>) : tensor<32x32xf32>
  // Bias Add
  %2 = tpp.add (%bias, %1) : tensor<32x32xf32>
  // ReLU
  %3 = tpp.relu (%2) : tensor<32x32xf32>

When lowering to BRGEMM, because the accumulation matrix (%0) was zero, we can move the bias into the accumulation (D = 0 followed by D += A x B + C becomes C += A x B), broadcasting the bias before the GEMM (this avoids an add and a zero).

  // New buffer
  %bias = tpp.broadcast (%arg2 : tensor<32xf32>) : tensor<32x32xf32>
  // Yet another buffer?
  %0 = fused_brgemm(%arg0, %arg1, %bias, ...) [ unary=relu ] : tensor<32x32xf32>

Here, the broadcast allocates a new tensor and the result of the BRGEMM (%0) can reuse the same buffer from the broadcast (%bias), avoiding one allocation. It would convert to something equivalent to:

  // Invalid IR, xsmm doesn't operate on tensors, but you get the idea
  %0 = xsmm.unary identity(%arg2) [ (broadcast_col0) ] : tensor<32x32xf32>
  %1 = xsmm.fused_brgemm(%arg0, %arg1, %0, ...) [ unary=relu ] : tensor<32x32xf32>

However, if the accumulating matrix isn't zero to begin with, we can't just move the bias up and we need to add it at the end.

  // Previous op, not zero -- This is the ONLY difference!
  %0 = tpp.something (...) : tensor<32x32xf32>

  // GEMM
  %1 = tpp.brgemm (%arg0, %arg1, %0) : tensor<32x32xf32>
  // Broadcasted bias
  %bias = tpp.broadcast (%arg2 : tensor<32xf32>) : tensor<32x32xf32>
  // Bias Add
  %2 = tpp.add (%bias, %1) : tensor<32x32xf32>
  // ReLU
  %3 = tpp.relu (%2) : tensor<32x32xf32>

Now, this does not map to a zero initialiser, so we need to fuse the add too, but we still need the broadcast:

  // Doesn't need to materialise, because the fused_brgemm has the "broadcast" flag on this operand
  %bias = tpp.broadcast (%arg2 : tensor<32xf32>) : tensor<32x32xf32>
  // Our first buffer allocation is here
  %0 = fused_brgemm(%arg0, %arg1, %bias, ...) [ binary=add, unary=relu ] : tensor<32x32xf32>

Here, %bias does not allocate anything. When lowering fused_brgemm to the xsmm dialect, the conversion function will check if the producer of %bias is a tpp.broadcast, and if it is, it will convert to something equivalent to:

  // Invalid IR, xsmm doesn't operate on tensors, but you get the idea
  %0 = xsmm.fused_brgemm(%arg0, %arg1, %arg2, ...) [ binary=add (broadcast_col0), unary=relu ] : tensor<32x32xf32>

Reduction and transpose operations behave in the same way.
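As a sketch (the exact op names tpp.reduce and tpp.transpose are assumptions here; the document only names them as materialisable operations), the materialised forms would be:

  // Materialised reduction: allocates the smaller buffer
  %red = tpp.reduce (%arg0 : tensor<32x32xf32>) : tensor<32xf32>
  // Materialised transpose: same element count, flipped dimensions
  %tr = tpp.transpose (%arg1 : tensor<16x32xf32>) : tensor<32x16xf32>

As with broadcast, either can instead become a flag on an operand of a nearby xsmm call when every user supports that read/write pattern.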

The reasons we don't have a flag at tpp level in the first place are:

  • Flags belong to operations, not operands.
  • Operand flags can create a combinatorial explosion of operation-level flags:
    • trans_arg0, trans_arg1, ...
    • vnni_arg0, vnni_arg1, ...
    • bcast_arg0, bcast_arg1, ...
  • We still need the materialisable operations (broadcast, reduce, transpose) anyway.

At the xsmm level, the operations are just function calls, which need to support all possible flags from libxsmm, so we need to carry them anyway; thus we convert TPP operations either into xsmm calls or into flags on other calls.

Binary Operations

All TPP binary operations are element-wise compute operations on 2D tiles.

XSMM supports broadcasting/reducing flags for binary calls, but because the flag depends on which operand it applies to, and could be fused into a BRGEMM nearby, we prefer to carry that information on the operand, as a unary op (tpp.broadcast, see above). Raising linalg generic operations to TPP requires understanding these subtleties.

The basic pattern is:

  // Out-of-line tensor
  %0 = tensor.empty() : tensor<32x32xf32>
  // Parallel element-wise map
  #map = affine_map<(d0, d1) -> (d0, d1)>
  // Single op inside
  %1 = linalg.generic {indexing_maps = [#map, #map, #map], 
                       iterator_types = ["parallel", "parallel"]}
           ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32x32xf32>) 
           outs(%0: tensor<32x32xf32>) {         // Note: %0 never read/written
       ^bb0(%in: f32, %in_1: f32, %out: f32):
          %add = arith.addf %in, %in_1 : f32     // Adds element-wise
          linalg.yield %add : f32                // Note: %out never read/written
  } -> tensor<32x32xf32>

Since a tensor operation always creates a new tensor, the empty is irrelevant and this lowers to:

  %1 = tpp.add(%arg0, %arg1) : tensor<32x32xf32>

If %0 was created by another operation (in-place), this is still a job for bufferization and we don't need to make it explicit.

However, there are two alternatives when it comes to affine maps that indicate broadcast or reduction:

  // Out-of-line tensor
  %0 = tensor.empty() : tensor<32x32xf32>
  // Parallel and broadcast element-wise map
  #parallel_map = affine_map<(d0, d1) -> (d0, d1)>
  #bcast_map = affine_map<(d0, d1) -> (d0)>
  // Single op inside
  %1 = linalg.generic {indexing_maps = [#parallel_map, #bcast_map, #parallel_map], 
                       iterator_types = ["parallel", "parallel"]}
           ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32xf32>) // Note: %arg0 is 2D while %arg1 is 1D
           outs(%0: tensor<32x32xf32>) {         // Note: %0 never read/written
       ^bb0(%in: f32, %in_1: f32, %out: f32):
          %add = arith.addf %in, %in_1 : f32     // Adds element-wise
          linalg.yield %add : f32                // Note: %out never read/written
  } -> tensor<32x32xf32>

This means either "broadcast %arg1 before adding", which needs an allocation, or "read the same row every time", which doesn't. Both semantics are supported by xsmm.

Here's the equivalent:

  // %arg1 is tensor<32xf32>
  %0 = tpp.broadcast(%arg1) : tensor<32x32xf32>
  // Note: both operands and the return value are 2D
  %1 = tpp.add(%arg0, %0) : tensor<32x32xf32>

If the broadcast map was applied to the first argument (%arg0), this would be:

  // %arg0 is tensor<32xf32>
  %0 = tpp.broadcast(%arg0) : tensor<32x32xf32>
  // Note: both operands and the return value are 2D
  %1 = tpp.add(%0, %arg1) : tensor<32x32xf32>

Reductions are similar, but the affine map that differs is on the output.
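A sketch of that pattern, summing each row of a 2D tensor into a 1D result (assuming %init is a zero-filled accumulator):

  #full = affine_map<(d0, d1) -> (d0, d1)>
  #rows = affine_map<(d0, d1) -> (d0)>

  %0 = linalg.generic {indexing_maps = [#full, #rows],
                       iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<32x32xf32>) outs(%init : tensor<32xf32>) {
      ^bb0(%in: f32, %out: f32):
        %add = arith.addf %out, %in : f32 // Note: reads outs, so %init is the accumulator
        linalg.yield %add : f32
  } -> tensor<32xf32>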

Also note that there are multiple types of broadcasting/reduction:

  • Scalar <-> ND
  • 2 x 1D (row, column) <-> ND
  • 3 x 2D (xy, yz, zx) <-> ND
  • etc.

All of this assumes the dimensions are multiples of each other. Padding would add another operation (tpp.pad).
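For example (a sketch; the exact semantics of tpp.pad are an assumption here), padding a tile up to a supported size would be its own materialising op:

  // Pads a 30x30 tile into a 32x32 buffer (padding value implementation-defined)
  %0 = tpp.pad (%arg0 : tensor<30x30xf32>) : tensor<32x32xf32>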

Transposes retain their dimensionality, but the dimensions are flipped.

The affine maps for a 2D transpose are:

  // Parallel input
  Input:  <(d0, d1) -> (d0, d1)>

  // Transposed output
  Output: <(d0, d1) -> (d1, d0)>
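A generic using these maps would look like (an illustrative sketch):

  #in_map  = affine_map<(d0, d1) -> (d0, d1)>
  #out_map = affine_map<(d0, d1) -> (d1, d0)>

  %0 = linalg.generic {indexing_maps = [#in_map, #out_map],
                       iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<16x32xf32>) outs(%init : tensor<32x16xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32 // Note: pure data movement, %out is never used
  } -> tensor<32x16xf32>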

But you can transpose an ND tensor in any two dimensions (ex. block-transpose):

  // Parallel input
  Input:  <(d0, d1, ..., dj, ..., dk, ..., dn) -> (d0, d1, ..., dj, ..., dk, ..., dn)>

  // Block-transposed output (note k <-> j)
  Output: <(d0, d1, ..., dj, ..., dk, ..., dn) -> (d0, d1, ..., dk, ..., dj, ..., dn)>