
Linalg Dialect Wish list

Renato Golin edited this page Aug 4, 2024 · 24 revisions

Main topics we need to fix in the linalg dialect.

Short Term Issues

Explicit broadcast

We want to add explicit broadcast semantics to linalg ops instead of relying on vague implicit semantics, which can change depending on the input language. Because linalg is an optimization dialect, it should have no strong opinion on semantics, but it should be able to represent all of them in a compact way.

Broadcasts should also be optional. If omitted, no broadcast is possible and the input and output shapes should all be identical.

There are two main ways to represent (optional) explicit broadcasts:

  • Direct attributes ({ broadcast_a = [1], broadcast_b = [0] }): Force broadcast on a specific dim. Easy to read and parse. 1:1 with affine maps.
  • Affine maps ({ map_a = affine_map<(d0, d1) -> (d0)>, map_b = affine_map<(d0, d1) -> (d1)> }): Explicit maps per operand. No ambiguity.

The first way is just syntactic sugar for the second; there is no change in semantics.
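To make the sugar concrete, here is a minimal Python sketch that expands a broadcast_* list into the equivalent indexing map text. The helper name broadcast_map is hypothetical (not an MLIR API), and it assumes linalg.broadcast's convention, where the listed dimensions are the iteration dims the operand does not have.

```python
def broadcast_map(out_rank, broadcast_dims):
    """Build affine_map text for an operand broadcast along `broadcast_dims`
    of a rank-`out_rank` iteration space.

    Hypothetical helper: follows linalg.broadcast's convention, where the
    listed dims are the ones the operand lacks.
    """
    dropped = set(broadcast_dims)
    dims = [f"d{i}" for i in range(out_rank)]
    kept = [d for i, d in enumerate(dims) if i not in dropped]
    return f"affine_map<({', '.join(dims)}) -> ({', '.join(kept)})>"


print(broadcast_map(2, [1]))  # affine_map<(d0, d1) -> (d0)>
print(broadcast_map(2, [0]))  # affine_map<(d0, d1) -> (d1)>
```

An empty list yields the identity map, which matches the "no broadcast" default.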

Problem Statement

Implicit broadcast can be confusing:

  // According to NumPy broadcast rules, this is an acceptable implicit broadcast
  %add = linalg.add ins(%0, %1 : tensor<64xf32>, tensor<16x64xf32>)
                    outs(%buf : tensor<16x64xf32>) -> tensor<16x64xf32>

  // Is this a valid broadcast? NumPy rules reject it (trailing dims 16 vs 64),
  // but the intent may well be to broadcast along dim 1.
  %add = linalg.add ins(%0, %1 : tensor<16xf32>, tensor<16x64xf32>)
                    outs(%buf : tensor<16x64xf32>) -> tensor<16x64xf32>

  // This definitely is:
  %bcast1 = linalg.broadcast ins(%0 : tensor<16xf32>)
                             outs(%buf : tensor<16x64xf32>) dimensions = [1]

  // What about this? Which dim should be broadcast?
  %add = linalg.add ins(%0, %1 : tensor<16xf32>, tensor<16x16xf32>)
                    outs(%buf : tensor<16x16xf32>) -> tensor<16x16xf32>

  // An explicit unit dimension helps guide the implicit broadcast, but canonicalization
  // may drop it. That should still be fine, as it only removes outer dims, not inner ones.
  %add = linalg.add ins(%0, %1 : tensor<1x16xf32>, tensor<16x16xf32>)
                    outs(%buf : tensor<16x16xf32>) -> tensor<16x16xf32>
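The ambiguity comes from how NumPy-style implicit broadcasting aligns shapes. A minimal Python sketch of that rule (the function name is illustrative, not from any library) shows which cases it accepts:

```python
def numpy_broadcast_shape(a, b):
    """Result shape of an implicit NumPy-style broadcast, or None if the
    shapes are incompatible. Shapes are aligned from the innermost dim."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1  # missing outer dims act like 1
        y = b[-i] if i <= len(b) else 1
        if x != y and x != 1 and y != 1:
            return None  # neither dim can be stretched to match
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


print(numpy_broadcast_shape((64,), (16, 64)))    # (16, 64)
print(numpy_broadcast_shape((16,), (16, 64)))    # None
print(numpy_broadcast_shape((16,), (16, 16)))    # (16, 16)
print(numpy_broadcast_shape((1, 16), (16, 16)))  # (16, 16)
```

The third call is the problematic one: the rule always aligns tensor<16xf32> with the innermost dimension, i.e. map (d0, d1) -> (d1), even if the frontend meant (d0, d1) -> (d0). Shapes alone cannot tell the two apart.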

Solutions

linalg.broadcast already has explicit dimensions, which remove the ambiguity.

So we could have a similar solution:

  %add = linalg.add ins(%0, %1 : tensor<16xf32>, tensor<16x64xf32>)
                    outs(%buf : tensor<16x64xf32>)
                    broadcast_a = [1]
                    -> tensor<16x64xf32>

Or in its more explicit affine map syntax:

  %add = linalg.add ins(%0, %1 : tensor<16xf32>, tensor<16x64xf32>)
                    outs(%buf : tensor<16x64xf32>)
                    { map_a = affine_map<(d0, d1) -> (d0)> }
                    -> tensor<16x64xf32>

Ultimately, implicit broadcast rules can be represented in MLIR without attributes or affine maps, but only by introducing tensor reshapes that make the broadcast dimension explicit (e.g. expanding a vector into a column vector, <M> -> <Mx1>), and those reshapes propagate throughout the graph.

Explicit transpose

We want to add explicit transpose semantics to linalg ops, so that we don't need to create a new op for each transpose type.

Transpose should also be optional. If omitted, no transpose is possible and the input and output shapes should all be identical.

There are three main ways to represent (optional) explicit transposes:

  • Direct attributes ({ transpose_a, transpose_b }): 1:1 mapping with current transposed ops. Only really works on 2D cases.
  • Permute attributes ({ permute_a = [1, 0], permute_b = [2, 0, 1] }): Allows any dim permutation. 1:1 with affine maps.
  • Affine maps ({ map_a = affine_map<(d0, d1) -> (d1, d0)>, map_b = affine_map<(d0, d1, d2) -> (d1, d2, d0)> }): Explicit maps per operand. No ambiguity.

The first way is broken: it does not generalize beyond 2D. We should move to a more general approach.

The second way is just syntactic sugar for affine maps; there is no change in semantics.
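One subtlety when desugaring: with linalg.transpose's convention (result dim i is operand dim perm[i]), the operand's indexing map is the inverse permutation, which is easy to get wrong by hand. A minimal Python sketch, using a hypothetical helper that is not an MLIR API:

```python
def permutation_map(perm):
    """affine_map text for an operand transposed by `perm`, where result
    dim i is operand dim perm[i] (linalg.transpose convention).

    Operand dim perm[i] is indexed by loop d{i}, so the map's results are
    the inverse permutation.
    """
    inverse = [0] * len(perm)
    for i, p in enumerate(perm):
        inverse[p] = i
    dims = ", ".join(f"d{i}" for i in range(len(perm)))
    results = ", ".join(f"d{i}" for i in inverse)
    return f"affine_map<({dims}) -> ({results})>"


print(permutation_map([0, 2, 1]))  # affine_map<(d0, d1, d2) -> (d0, d2, d1)>
print(permutation_map([1, 2, 0]))  # affine_map<(d0, d1, d2) -> (d2, d0, d1)>
```

Self-inverse permutations such as [1, 0] or [0, 2, 1] hide the issue; [1, 2, 0] does not.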

Problem Statement

Transposes have some implicit semantics based on tensor shapes:

  // This could be interpreted as a transpose: A^T + B -> C
  // But it could also be interpreted as (A + B^T)^T -> C
  // For element-wise ops, they're the same
  %add = linalg.add ins(%0, %1 : tensor<64x16xf32>, tensor<16x64xf32>)
                    outs(%buf : tensor<16x64xf32>) -> tensor<16x64xf32>

  // This can be interpreted as a transpose: A^T x B -> C
  // If this were A x B^T, the result type would be tensor<16x16xf32>
  %mm = linalg.matmul ins(%0, %1 : tensor<16x64xf32>, tensor<16x64xf32>)
                      outs(%buf : tensor<64x64xf32>) -> tensor<64x64xf32>

But we have the same problem as above when the dimensions are identical:

  // If we want to transpose either A or B, there's no way to represent that in shapes
  %add = linalg.add ins(%0, %1 : tensor<64x64xf32>, tensor<64x64xf32>)
                    outs(%buf : tensor<64x64xf32>) -> tensor<64x64xf32>

  // Nor here...
  %mm = linalg.matmul ins(%0, %1 : tensor<64x64xf32>, tensor<64x64xf32>)
                      outs(%buf : tensor<64x64xf32>) -> tensor<64x64xf32>

  // This is ambiguous: which operand dims map to which output dims?
  %add = linalg.add ins(%0, %1 : tensor<64x32x64xf32>, tensor<32x64x64xf32>)
                    outs(%buf : tensor<64x64x32xf32>) -> tensor<64x64x32xf32>
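Counting the transposes that the shapes allow makes the ambiguity explicit. This Python sketch (illustrative names, linalg.transpose convention assumed) lists every permutation p with out_shape[i] == in_shape[p[i]]:

```python
from itertools import permutations


def candidate_permutations(in_shape, out_shape):
    """All permutations p (linalg.transpose convention) that turn
    `in_shape` into `out_shape`, i.e. every transpose the shapes allow."""
    return [
        p
        for p in permutations(range(len(in_shape)))
        if all(out_shape[i] == in_shape[j] for i, j in enumerate(p))
    ]


out = (64, 64, 32)
print(candidate_permutations((64, 32, 64), out))  # [(0, 2, 1), (2, 0, 1)]
print(candidate_permutations((32, 64, 64), out))  # [(1, 2, 0), (2, 1, 0)]
print(candidate_permutations((64, 64), (64, 64)))  # [(0, 1), (1, 0)]
```

Whenever two dims have the same size, more than one permutation fits, so only an explicit attribute or map can say which one was meant.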

Solutions

linalg.transpose already has an explicit permutation, which removes the ambiguity.

So we could have a similar solution:

  %add = linalg.add ins(%0, %1 : tensor<64x32x64xf32>, tensor<32x64x64xf32>)
                    outs(%buf : tensor<64x64x32xf32>)
                    permutation_a = [0, 2, 1], permutation_b = [1, 2, 0]
                    -> tensor<64x64x32xf32>

Or in its more explicit affine map syntax:

  %add = linalg.add ins(%0, %1 : tensor<64x32x64xf32>, tensor<32x64x64xf32>)
                    outs(%buf : tensor<64x64x32xf32>)
                    map_a = affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
                    map_b = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
                    -> tensor<64x64x32xf32>

Note: Case for affine maps

The two issues above have one solution in common: affine maps.

If we allow ops to have their affine maps overridden by an optional attribute, we mostly just need to change the parser/printer. Verification will still check that the affine maps are consistent, generalization will still use those affine maps to lower to generics, etc.
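The consistency check mentioned above could look roughly like this Python sketch. It is not the actual MLIR verifier: it infers each loop's bound from the operand shapes through the indexing maps and rejects conflicts. Encoding each map as a tuple of loop indices is an assumption for brevity, so only broadcasts and transposes (projected permutations) are covered.

```python
def infer_loop_bounds(maps, shapes, num_loops):
    """Infer loop bounds from operand shapes and check that they agree.

    maps[k][j] is the loop index (the `d` number) used by dim j of operand k.
    General affine expressions are out of scope for this sketch.
    """
    bounds = [None] * num_loops
    for k, (indexing, shape) in enumerate(zip(maps, shapes)):
        if len(indexing) != len(shape):
            raise ValueError(f"operand {k}: map rank != tensor rank")
        for loop, size in zip(indexing, shape):
            if bounds[loop] is None:
                bounds[loop] = size
            elif bounds[loop] != size:
                raise ValueError(
                    f"operand {k}: d{loop} is {size}, expected {bounds[loop]}"
                )
    return bounds


# linalg.add with map_a = (d0, d2, d1), map_b = (d2, d0, d1), identity output.
maps = [(0, 2, 1), (2, 0, 1), (0, 1, 2)]
shapes = [(64, 32, 64), (32, 64, 64), (64, 64, 32)]
print(infer_loop_bounds(maps, shapes, 3))  # [64, 64, 32]
```

The same function handles the broadcast case, e.g. maps [(0,), (0, 1), (0, 1)] for a tensor<16xf32> operand broadcast along dim 1.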

Explicit compute type

The current semantics of some linalg ops is to cast the input types to the output type before performing the computation. This works when the output type is wider than the inputs, but not the other way around.

Encoding arbitrary cast semantics would be too complex, but adding a single compute_type attribute simplifies it considerably.

The semantics becomes:

  • All input types cast to the compute type
  • Operation is performed on the compute type
  • Result is cast to the output type and returned

If an input type is the same as the compute type, its cast is a no-op; the same applies to the output. Later cleanup passes can remove such casts, but it is also simple to skip them during generalization/lowering.

By default, compute_type is equal to the output type, so if the attribute is omitted, the behaviour does not change from the current semantics.
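The semantics above can be modelled in a few lines of Python for one output element of a matmul. This is a sketch only: the cast helper models just f32 (rounding through struct) and i8 (truncating toward zero like arith.fptosi, and raising on overflow instead of modelling poison), and inputs are plain Python floats rather than bf16.

```python
import struct


def cast(value, ty):
    """Minimal element-type model covering only the types used here."""
    if ty == "f32":
        return struct.unpack("f", struct.pack("f", float(value)))[0]
    if ty == "i8":
        v = int(value)  # truncates toward zero, like arith.fptosi
        if not -128 <= v <= 127:
            raise OverflowError(f"{value} does not fit in i8")
        return v
    raise ValueError(f"unsupported type {ty}")


def matmul_element(a_row, b_col, out, out_type, compute_type=None):
    """Cast inputs and accumulator to the compute type, compute, cast back."""
    ct = compute_type or out_type  # default: compute in the output type
    acc = cast(out, ct)
    for x, y in zip(a_row, b_col):
        acc = cast(acc + cast(x, ct) * cast(y, ct), ct)
    return cast(acc, out_type)


a = [0.5, 0.5, 0.5, 0.5]
b = [1.5, 1.5, 1.5, 1.5]
print(matmul_element(a, b, 0, "i8"))         # current semantics: 0
print(matmul_element(a, b, 0, "i8", "f32"))  # compute_type = f32: 3
```

Casting 0.5 to i8 first truncates it to 0, so the default path loses the whole result, while computing in f32 and casting once at the end gives 3.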

Problem Statement

The current implicit element-type cast in linalg requires input types to be cast to the output type before computation:

  // This named op:
  %0 = linalg.matmul ins(%a, %b : tensor<16x64xbf16>, tensor<64x32xf32>) outs(%buf : tensor<16x32xi8>) -> tensor<16x32xi8>

  // Becomes:
  %0 = linalg.generic ...
       {
         ^bb0(%in1: bf16, %in2: f32, %out: i8):
           %cast_a = arith.fptosi %in1 : bf16 to i8
           %cast_b = arith.fptosi %in2 : f32 to i8
           %mul = arith.muli %cast_a, %cast_b : i8
           %acc = arith.addi %mul, %out : i8
           linalg.yield %acc : i8
       }

  // When what you probably want is:
  %0 = linalg.generic ...
       {
         ^bb0(%in1: bf16, %in2: f32, %out: i8):
           %cast_a = arith.extf %in1 : bf16 to f32
           %cast_out = arith.sitofp %out : i8 to f32
           %mul = arith.mulf %cast_a, %in2 : f32
           %acc = arith.addf %mul, %cast_out : f32
           %cast = arith.fptosi %acc : f32 to i8
           linalg.yield %cast : i8
       }

Solutions

There are too many ways to describe the various casts between inputs and outputs, so we can perhaps simplify everything to a compute type.

The expected semantics is that all inputs are cast to the compute type and from the compute type to the output type:

  // This named op:
  %0 = linalg.matmul ins(%a, %b : tensor<16x64xbf16>, tensor<64x32xf32>) outs(%buf : tensor<16x32xi8>) compute_type = f32 -> tensor<16x32xi8>

  // Becomes exactly the expected generic above:
  %0 = linalg.generic ...
       {
         ^bb0(%in1: bf16, %in2: f32, %out: i8):
           %cast_a = arith.extf %in1 : bf16 to f32
           // %in2 is already f32, so its cast is omitted
           %cast_out = arith.sitofp %out : i8 to f32
           %mul = arith.mulf %cast_a, %in2 : f32
           %acc = arith.addf %mul, %cast_out : f32
           %cast = arith.fptosi %acc : f32 to i8
           linalg.yield %cast : i8
       }

This solution helps with implicit (non-explicit) quantization. For explicit quantization, the compiler would have to create "bubbles" of computation in the compute type, surrounded by quantized types, so the output types of the intermediate operations would not be quantized anyway.

  // First dequantize
  %0 = dequantize(%a) : tensor<...xi8> to tensor<...xf32>
  %1 = dequantize(%b) : tensor<...xi8> to tensor<...xf32>
  %2 = dequantize(%c) : tensor<...xi8> to tensor<...xf32>
  %3 = dequantize(%d) : tensor<...xi8> to tensor<...xf32>

  // "Bubble"
  %4 = op1(%0, %1) : tensor<...xf32>
  %5 = op1(%2, %4) : tensor<...xf32>
  %6 = op1(%3, %5) : tensor<...xf32>

  // Re-quantize
  %7 = quantize(%6) : tensor<...xf32> to tensor<...xi8>
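The bubble above can be mimicked in Python to show that only the boundaries touch the storage type. Assumptions not taken from this page: per-tensor affine quantization with made-up scale and zero point, op1 standing in for an element-wise add, and Python's round (round half to even).

```python
def dequantize(q, scale, zero_point):
    # i8 storage -> f32 compute
    return [(v - zero_point) * scale for v in q]


def quantize(x, scale, zero_point):
    # f32 compute -> i8 storage, saturating to the i8 range
    return [max(-128, min(127, round(v / scale) + zero_point)) for v in x]


def op1(x, y):
    # Hypothetical stand-in for op1: element-wise add in the compute type
    return [u + v for u, v in zip(x, y)]


SCALE, ZERO_POINT = 0.5, 0  # hypothetical quantization parameters
a, b, c, d = [2, 4], [1, 1], [0, 2], [3, -1]

# First dequantize
t0, t1, t2, t3 = (dequantize(q, SCALE, ZERO_POINT) for q in (a, b, c, d))
# "Bubble": every intermediate stays in the compute type
t4 = op1(t0, t1)
t5 = op1(t2, t4)
t6 = op1(t3, t5)
# Re-quantize
print(quantize(t6, SCALE, ZERO_POINT))  # [6, 6]
```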

However, the compute type can help the compiler create such bubbles by knowing which accumulation type the hardware expects (e.g. specific registers or instructions). One could, for example, use that information to propagate the compute type across storage-type boundaries and, using hardware-specific operations on higher-precision accumulators, construct those quantization groups.

Matmul transpose

OpDSL forced us to create multiple matmul variants, one per transpose.

Now we need to create transposed versions of batch_reduce_matmul, which has 3D inputs: 3 transpose types per operand result in 6 variants, even without combining A and B transposes. This will not scale.
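The arithmetic behind "this will not scale" is easy to check. The sketch assumes the "3 transpose types" for a 3D operand are the three pairwise dim swaps; allowing full permutations would give 5 non-identity layouts per operand and grow faster still.

```python
from itertools import combinations, product


def pairwise_transposes(rank):
    """Single two-dim swaps of a rank-`rank` operand."""
    return list(combinations(range(rank), 2))


per_operand = pairwise_transposes(3)  # [(0, 1), (0, 2), (1, 2)]
one_operand_only = 2 * len(per_operand)  # transpose A or B, not both
# Each operand is untransposed (None) or uses one swap; drop the
# combination where neither operand is transposed.
any_combination = len(list(product([None] + per_operand, repeat=2))) - 1
print(one_operand_only, any_combination)  # 6 15
```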

The solution is to use the explicit transpose mechanism above on matmuls. The issue is that those operations are heavily used in pattern matching, upstream and downstream.

There are four ways of solving this:

  1. Find a way to add transpose semantics to the existing matmul operation, slowly replace usage downstream and eventually retire the transpose variants.
  2. Create an alternative matmul_transpose op that encodes the semantics of all types of transposes, match against it downstream, remove the existing transpose variants, make matmul_transpose without attributes behave like a plain matmul, remove matmul and rename matmul_transpose to just matmul. Do the same with the batch/reduce variants.
  3. Coordinate with downstream projects, assess the impact, and make a preemptive change on a branch (in both LLVM and the project), then switch upstream and cascade downstream.
  4. Wait for the contract operation (below) to be available and useful, switch to contractions everywhere, then agree on retiring the matmul variants to just aliases of contract.

The first one is the least problematic, but it does expand OpDSL usage, so it should only be considered if we can find a way that does not require changing OpDSL itself.

The second one is weird (using a matmul_transpose for simple matmuls) but is the shortest and safest path to success if the first does not work.

The third one can cause a lot of churn downstream and is best avoided.

The fourth one risks never completing. Designing a contract op is much more complex than matmul or even the convolution variants, which all have strict semantics. contract is as generic as it comes and needs to cater to all needs.

Long Term Issues

The issues below are important, but need to be handled with care and after the issues above have been fixed, or at least agreed and being worked on.

Grouping

Adding a grouping op can help heuristic searches with cost models find optimal fusion opportunities in complex graphs. The search would apply simple insert/extract/reorder transformations to an arbitrary nest of groupings until optimal tiling opportunities within groups and minimal communication across groups are achieved.

This work requires:

  • Strong semantic guarantees and flexibility on the operations, so we can create non-destructive rules to reorder or move them across sub-groups. Named ops make this a much easier job.
  • A complete set of transforms that allows us to safely move operations inside and outside of groups, reorder the IR (while keeping SSA semantics), etc.
  • A robust cost model, taking advantage of an extensive target description, with which to analyse the state and return a value to the search.
  • A heuristic search algorithm, with set budgets and targets, to drive the IR through the transformations and cost analysis towards a local minimum.

Optional outcomes would be:

  • Early transformations can leave information in the target descriptor for later transformations to use, for the things the IR can't represent well.
  • Early transformations can "ask" later ones "what would you do if I did this?", so that separate high- and low-level cost models can query each other for better code generation.

Contract operation

This will have its own separate document, as it's a huge proposal, but here are the headlines:

  • The idea is to implement a contract operation that performs an Einstein summation (einsum).
  • Existing examples: numpy.einsum and torch.einsum.
  • Represents element-wise, inner/outer product, transposes, broadcasts, reductions. It's a catch-all operation for summation.
  • Not explicit about the order of execution, especially as the number of dims increases, with multiple reductions, etc.
  • Complex einsums can be decomposed into a tree of simpler ones (which can be efficiently implemented in hardware), but finding an optimal tree is a very hard problem.
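To illustrate the catch-all nature, here is a naive reference einsum in Python over nested lists. It sketches the semantics only (explicit "->" specs, no ellipsis, no broadcasting) and simply visits every index combination in an arbitrary sorted-label order, which is exactly the freedom a contract op leaves open; it says nothing about an efficient decomposition.

```python
from itertools import product


def _shape(t):
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return shape


def _get(t, idx):
    for i in idx:
        t = t[i]
    return t


def _zeros(shape):
    return 0 if not shape else [_zeros(shape[1:]) for _ in range(shape[0])]


def einsum(spec, *operands):
    inputs, output = spec.replace(" ", "").split("->")
    in_labels = inputs.split(",")
    sizes = {}
    for labels, op in zip(in_labels, operands):
        for label, n in zip(labels, _shape(op)):
            if sizes.setdefault(label, n) != n:
                raise ValueError(f"size mismatch for index {label!r}")
    result = _zeros([sizes[l] for l in output])
    all_labels = sorted(sizes)  # iteration order is an arbitrary choice
    for values in product(*(range(sizes[l]) for l in all_labels)):
        env = dict(zip(all_labels, values))
        term = 1
        for labels, op in zip(in_labels, operands):
            term *= _get(op, [env[l] for l in labels])
        out_idx = [env[l] for l in output]
        if out_idx:
            _get(result, out_idx[:-1])[out_idx[-1]] += term
        else:
            result += term  # full reduction to a scalar
    return result


A, B = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
print(einsum("ij,jk->ik", A, B))  # matmul: [[19, 22], [43, 50]]
print(einsum("ij->ji", A))        # transpose: [[1, 3], [2, 4]]
print(einsum("ii->", A))          # trace (reduction): 5
```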

For linalg, we really should not aim for the whole semantics, and restrict ourselves to some core usages: Matrix/vector multiply and accumulate over non-standard dimensions. Standard matrix multiplications can still be represented by their own operations and make the life of pattern matchers easier.

We could eventually just alias matmul operations to the new contract operation. If the contract op is truly generic, then mapping all matmul operations to contractions should be a trivial exercise.

However, keeping the aliases for as long as possible allows pattern matchers to match against matmul and not its more complicated einsum representation.

The problem with pattern matchers, and with lowering semantics too early, is the main reason we propose not taking the contract route right away, and focusing on a matmul_transpose option for now. It's not clear how the contract route will fare long term; it may not be the right thing for really simple operations like matmul.