
Linalg matmul with affine maps

Renato Golin edited this page Aug 22, 2024 · 16 revisions

We have an explosion of matmul variants in opDSL, and we want to consolidate them into a single syntax covering all three matmul variants.

Previous discussion concluded that opDSL isn't powerful enough to represent this, so we have to move them to table-gen.

Goals

  1. To add syntax to matmul without changing any of the existing syntax expectations for current usage. matmul is still just matmul.
  2. To expose permutations and broadcast semantics on the three matmul variations: matmul, batch_matmul and batch_reduce_matmul.

Non-Goals

  1. To change how we use these matmul variations (semantics or syntax), including their transposed versions, upstream or downstream.
  2. To remove their transposed versions from the dialect now.
  3. To replace matmul with a generic contraction operation that is far too complex to design and agree upon in the same time frame.

Practicals

OpDSL vs. Table-Gen

Because matmul is defined in opDSL and we want to move it to table-gen, we must remove the opDSL definition before we can add it to table-gen. But this is not a functional change, just code movement. The old syntax and semantics still need to be 100% identical. The way the ops are used in C++ or IR cannot change in any way or form.

Batch and Reduce Variants

The syntax and semantics of the batch and reduce variants need to be identical to the base matmul op. We should discuss what it looks like on matmul first (because it's easier and separates concerns), and then extend it to the other two.

Landing the PRs one at a time or all three at once doesn't matter much, as long as we all agree on the syntax of the three variants in unison.

Syntax

As discussed in the RFC, in the document we wrote, and again on the PR, we should use affine maps for both transpose and broadcast.

Baseline

  // No maps still _must_ continue to work
  // This means (m, k) x (k, n) -> (m, n)
  %0 = linalg.matmul
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

Indexing Maps List

  // Fully exposed maps can still work, even if same semantics
  %0 = linalg.matmul
              indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
                       affine_map<(d0, d1, d2) -> (d2, d1)>, // k, n
                       affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
              ]
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

  // Fully exposed maps work on transposes, but require the rest to also be explicit,
  // because `indexing_maps` is a list
  %0 = linalg.matmul
              indexing_maps = [ 
                       affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
                       affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
                       affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
              ]
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

  // It also works when you have multiple cases
  %0 = linalg.matmul
              indexing_maps = [ 
                       affine_map<(d0, d1, d2) -> (d0)>,     // m (broadcast column)
                       affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
                       affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
              ]
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

Null Map Idea

  // We could try to create a _"null map"_ that means "default behaviour". Still looks ugly.
  %0 = linalg.matmul
              indexing_maps = [ 
                       affine_map<>,                         // m, k (default behaviour)
                       affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
                       affine_map<>                          // m, n (default behaviour)
              ]
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

  // This could be made to look "nicer" with named maps
  #default = affine_map<>
  #transpose_b = affine_map<(d0, d1, d2) -> (d1, d2)>
  %0 = linalg.matmul
              indexing_maps = [ 
                       #default,     // m, k (default behaviour)
                       #transpose_b, // n, k (transposed B)
                       #default      // m, n (default behaviour)
              ]
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

Named Map Idea

  // We could try naming the maps. Makes it easier to only have one "overriding" map.
  %0 = linalg.matmul
              map_b = affine_map<(d0, d1, d2) -> (d1, d2)> // n, k (transposed B)
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

  // Are `a`, `b`, `c` really the best names here?
  %0 = linalg.matmul
              map_a = affine_map<(d0, d1, d2) -> (d0)>,    // m (broadcast column)
              map_b = affine_map<(d0, d1, d2) -> (d1, d2)> // n, k (transposed B)
              ins(%a, %b: tensor<>, tensor<>)
              outs (%c: tensor<>) : tensor<>

Batch / Reduce

Batched matmuls have 3D inputs and outputs; the extra dimension is a "batch" dimension, and the matrices are multiplied one by one along it. Batch-reduce matmuls do a further reduction in addition to the batch multiply, so while the inputs are still 3D, the output is a 2D "tile" holding that accumulation.

Batch Matmul

In addition to the transpose and broadcast semantics of matmul, the batch version batch_matmul needs agreement on which dimension is the batch dimension of each operand.

Currently, the assumption is that the "outer" dimension is the batch dimension and the "transpose" variants only swap the inner dimensions.

  // Standard batch matmul
  %0 = linalg.batch_matmul
              ins(%a, %b: tensor<BxMxK>, tensor<BxKxN>)
              outs (%c: tensor<BxMxN>) : tensor<BxMxN>

  // Transpose A batch matmul (note KxM on the first arg)
  %0 = linalg.batch_matmul_transpose_a
              ins(%a, %b: tensor<BxKxM>, tensor<BxKxN>)
              outs (%c: tensor<BxMxN>) : tensor<BxMxN>


  // Transpose B batch matmul (note NxK on the second arg)
  %0 = linalg.batch_matmul_transpose_b
              ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
              outs (%c: tensor<BxMxN>) : tensor<BxMxN>

This should be trivial to lower using affine maps:

  // Transpose B batch matmul
  %0 = linalg.batch_matmul
              indexing_maps = [ 
                       affine_map<(dB, d0, d1, d2) -> (dB, d0, d2)>, // batch, m, k
                       affine_map<(dB, d0, d1, d2) -> (dB, d1, d2)>, // batch, n, k (transposed B)
                       affine_map<(dB, d0, d1, d2) -> (dB, d0, d1)>  // batch, m, n
              ]
              ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
              outs (%c: tensor<BxMxN>) : tensor<BxMxN>

But adding affine maps may open the door to "arbitrary" contraction semantics, for example, making the inner dimension the "batch" dimension and the others the "matmul" dimensions. However, this would move the semantics of the batch_matmul operations towards a more generic contract operation, and that's a non-goal.

So we propose keeping the semantics of the op fixed and moving the checks to the verifier: custom affine maps must not violate these assumptions, with errors emitted if a user-defined affine map does not match the expectations.

Requirements:

  • The dB dim is always the batch dimension, the other three are the "matmul" dimensions.
  • Transpose can only occur on "matmul" dimensions.
  • Broadcast can happen on both "batch" and "matmul" dimensions.
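
As an illustration (this exact IR is hypothetical), a map that moves the batch dimension inward swaps dB with a "matmul" dimension, which the verifier should reject:

  // Invalid: dB is not the outer dimension of A, so this swaps batch with a
  // "matmul" dimension and describes a different contraction.
  // The verifier should emit an error for the first map.
  %0 = linalg.batch_matmul
              indexing_maps = [
                       affine_map<(dB, d0, d1, d2) -> (d0, dB, d2)>, // m, batch, k: invalid
                       affine_map<(dB, d0, d1, d2) -> (dB, d2, d1)>, // batch, k, n
                       affine_map<(dB, d0, d1, d2) -> (dB, d0, d1)>  // batch, m, n
              ]
              ins(%a, %b: tensor<MxBxK>, tensor<BxKxN>)
              outs (%c: tensor<BxMxN>) : tensor<BxMxN>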

Valid transpose maps:

  #transpose_a = affine_map<(dB, d0, d1, d2) -> (dB, d2, d0)>
  #transpose_b = affine_map<(dB, d0, d1, d2) -> (dB, d1, d2)>
  #transpose_c = affine_map<(dB, d0, d1, d2) -> (dB, d2, d1)>

Valid broadcast maps:

  #broadcast_col_a = affine_map<(dB, d0, d1, d2) -> (dB, d0)>
  #broadcast_row_a = affine_map<(dB, d0, d1, d2) -> (dB, d2)>
  #broadcast_batch_a = affine_map<(dB, d0, d1, d2) -> (d0, d2)>

  #broadcast_col_b = affine_map<(dB, d0, d1, d2) -> (dB, d2)>
  #broadcast_row_b = affine_map<(dB, d0, d1, d2) -> (dB, d1)>
  #broadcast_batch_b = affine_map<(dB, d0, d1, d2) -> (d2, d1)>

  #broadcast_col_c = affine_map<(dB, d0, d1, d2) -> (dB, d0)>
  #broadcast_row_c = affine_map<(dB, d0, d1, d2) -> (dB, d1)>
  #broadcast_batch_c = affine_map<(dB, d0, d1, d2) -> (d0, d1)>

Since transpose can only occur on "matmul" dimensions, a transpose+broadcast map is just the complementary broadcast map:

  #broadcast_col_a = affine_map<(dB, d0, d1, d2) -> (dB, d0)>
  #broadcast_row_a = affine_map<(dB, d0, d1, d2) -> (dB, d2)>

  #transpose_broadcast_col_a = affine_map<(dB, d0, d1, d2) -> (dB, d2)> // Same as row_a

The input type must be appropriately shaped to accept that map (so possibly extended from <BxN> to <Bx1xN>).
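
For example, a batch broadcast on A (the #broadcast_batch_a map above) would look like this sketch, where %a carries no batch dimension and is reused for every batch of %b:

  // Batch broadcast on A: the same 2D %a multiplies every batch of %b
  %0 = linalg.batch_matmul
              indexing_maps = [
                       affine_map<(dB, d0, d1, d2) -> (d0, d2)>,     // m, k (batch broadcast)
                       affine_map<(dB, d0, d1, d2) -> (dB, d2, d1)>, // batch, k, n
                       affine_map<(dB, d0, d1, d2) -> (dB, d0, d1)>  // batch, m, n
              ]
              ins(%a, %b: tensor<MxK>, tensor<BxKxN>)
              outs (%c: tensor<BxMxN>) : tensor<BxMxN>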

Batch-Reduce Matmul

The batch-reduce version is similar to batch matmul, but with an additional reduction on the batch dimension. The same argument for restricting the affine maps to "known configurations", to avoid general contraction semantics, applies.

  // Standard batch reduce matmul
  %0 = linalg.batch_reduce_matmul
              ins(%a, %b: tensor<BxMxK>, tensor<BxKxN>)
              outs (%c: tensor<MxN>) : tensor<MxN>

  // Transpose A batch reduce matmul (note KxM on the first arg)
  %0 = linalg.batch_reduce_matmul_transpose_a // Note: this op does not exist, but is extrapolated from the rest
              ins(%a, %b: tensor<BxKxM>, tensor<BxKxN>)
              outs (%c: tensor<MxN>) : tensor<MxN>


  // Transpose B batch reduce matmul (note NxK on the second arg)
  %0 = linalg.batch_reduce_matmul_transpose_b // Note: this op does not exist, but is extrapolated from the rest
              ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
              outs (%c: tensor<MxN>) : tensor<MxN>

The same restrictions as for batch_matmul apply, and in addition:

  • There is no "batch broadcast" for the C argument, since it's a 2D shape.
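
Following the batch_matmul example above, a sketch of the explicit form for a transposed-B batch reduce matmul; the reduction is encoded by dB being absent from the output map:

  // Transpose B batch reduce matmul with explicit maps
  %0 = linalg.batch_reduce_matmul
              indexing_maps = [
                       affine_map<(dB, d0, d1, d2) -> (dB, d0, d2)>, // batch, m, k
                       affine_map<(dB, d0, d1, d2) -> (dB, d1, d2)>, // batch, n, k (transposed B)
                       affine_map<(dB, d0, d1, d2) -> (d0, d1)>      // m, n (batch reduced away)
              ]
              ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
              outs (%c: tensor<MxN>) : tensor<MxN>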

Form Conversion

With these changes applied, there will be three ways to represent a broadcast/transpose matmul in linalg:

  1. linalg.generic with explicit maps (generic form)
  2. linalg.matmul with explicit maps (explicit form)
  3. Sequence of linalg.broadcast, linalg.transpose, and linalg.matmul with implicit maps (DAG form)
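
For a transposed-B matmul, forms 2 and 3 would look roughly like this (form 2 uses the proposed syntax; form 3 uses today's upstream ops, with %bt_init an illustrative init tensor):

  // Explicit form: the transpose lives in the indexing maps
  %0 = linalg.matmul
              indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
                       affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
                       affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
              ]
              ins(%a, %b: tensor<MxK>, tensor<NxK>)
              outs (%c: tensor<MxN>) : tensor<MxN>

  // DAG form: the transpose is materialized as a separate op
  %bt = linalg.transpose
              ins(%b: tensor<NxK>)
              outs(%bt_init: tensor<KxN>) permutation = [1, 0]
  %0 = linalg.matmul
              ins(%a, %bt: tensor<MxK>, tensor<KxN>)
              outs (%c: tensor<MxN>) : tensor<MxN>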

Generalization

We already have passes that "generalize" linalg operations from named operations into their generic forms when the named ops represent a perfectly nested loop nest. These should continue to work with user-defined affine maps; the generic form will simply carry the user-defined maps instead of the default ones. This is trivial because, whatever the syntax, the internal representation of the maps should be stored as a list of affine maps at construction time.
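
For reference, generalizing a transposed-B matmul should produce a generic carrying exactly the user-defined maps (a sketch; f32 element type assumed for illustration):

  %0 = linalg.generic {
              indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
                       affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
                       affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
              ],
              iterator_types = ["parallel", "parallel", "reduction"]
       } ins(%a, %b: tensor<MxK>, tensor<NxK>)
         outs(%c: tensor<MxN>) {
       ^bb0(%in_a: f32, %in_b: f32, %acc: f32):
              %1 = arith.mulf %in_a, %in_b : f32
              %2 = arith.addf %acc, %1 : f32
              linalg.yield %2 : f32
       } -> tensor<MxN>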

However, what is missing is generalization from the DAG form (ex. matmul(a, transpose(b), c)). This requires more complex matchers, which are doable but not done yet. Even when operands are themselves the results of transpose/broadcast operations, one can propagate that through: pass the original (un-transposed/un-broadcast) arguments and update the affine maps accordingly.

De-generalization

There is ongoing work to de-generalize linalg generic forms into named operations. These currently work on a small subset of operations and construct neither a DAG nor named operations with indexing maps (because the latter don't exist yet). Again, it should be trivial to generate the latter for the alternatives that are supported (it may need a long if/else chain, though).

Raising generics to a sequence of named operations, however, would need a matching structure that separates the semantics of affine maps and iterator types, to be able to use multiple named operations in a chain.

DAG to Explicit

There is no such transform in MLIR today, but it would not be much different than the above proposals between generics and named forms.

Going from the DAG form to the explicit form is mostly trivial: the transform knows which operation maps to which affine map and just updates it on the payload operation. The complexity comes when the DAG is not a chain (ex. there are users outside the chain), or when more than one shape-altering operation produces an operand (ex. matmul(transpose(broadcast(a)), transpose(b), broadcast(c))).

Going from the explicit form to the DAG form is simple for the common cases, where we know what the affine map represents (given the source operation: element-wise, matmul) and can map it to a particular external operation. But this needs a precise mapping between the semantics of the external operations (encoded in a language reference, not just defined by their implementation).

Implementation Details

We should reuse as much as possible of the C++ code that calculates and verifies the affine maps. If linalg.broadcast uses some functions to validate its maps against some shapes, we need to use the same functions, which may have to be extracted into a common utility. Same for transpose.

In effect, calculating a chain of shape transformations should be simply the following sequence:

  1. Ask the operation what is the default map for that particular operand
  2. Use the aforementioned functions to apply the affine transformations to the map
  3. Persist that as the final version of that particular map

If we end up with a combination of the explicit and DAG forms (i.e. a transpose on top of a transposed matmul), we need to be able to reason about the semantics. The easiest way is to simply apply the same functions to the user-defined map and note that it is then back in its default configuration.

A matmul printer that notices the affine map is in a "default configuration" can omit its syntax, thus converting it to a simple matmul without explicit maps.
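
For example (a sketch, with illustrative SSA names), a transpose feeding a matmul that already marks B as transposed cancels out when the maps are composed, so the whole chain folds back to a plain matmul on the original %b:

  // %bt re-transposes %b, and the matmul marks B as transposed again
  %bt = linalg.transpose
              ins(%b: tensor<KxN>)
              outs(%bt_init: tensor<NxK>) permutation = [1, 0]
  %0 = linalg.matmul
              indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
                       affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
                       affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
              ]
              ins(%a, %bt: tensor<MxK>, tensor<NxK>)
              outs (%c: tensor<MxN>) : tensor<MxN>

  // The two transposes cancel: composing the maps yields the default
  // configuration, which prints as a plain matmul on the original %b
  %0 = linalg.matmul
              ins(%a, %b: tensor<MxK>, tensor<KxN>)
              outs (%c: tensor<MxN>) : tensor<MxN>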

This is a very straightforward way of achieving canonicalization by construction that does not need a special pass: just iterate until all operations are in their most compact representations. A configuration parameter could list operations that should not be folded (for ex. linalg.fill), thus finding a fixed point where they still exist while all the others have been converted to affine maps on named operations.

This would solve the "softmax problem" exposed in https://github.com/llvm/llvm-project/pull/97582.

Syntax Discussion

There are three variants above that could work; we need to pick one and stick with it.

  1. Indexing maps are either fully off (default behaviour) or fully on (all maps explicit, even if only one differs from the default).
  2. Indexing maps are still a list, but we have a "null element" that represents the default behaviour (purely syntactic sugar).
  3. Indexing maps are named and only those that are different from the default behaviour need to be explicit.

Any of the three solutions must still abide by the Goals and Non-Goals above.

Full list

Just add an optional indexing_maps attribute to the operation.

Pros:

  • Easiest to implement: implicit use still keeps the automatic maps, explicit use replaces the whole list and just needs validation
  • Looks like generics: syntax-wise, it's just like generics and we can use well named out-of-line maps to make it easier to read
  • Best case for multiple custom maps: when most or all inputs have custom maps

Cons:

  • Verbose when only one map is different: the whole list must be spelled out for all operands, for every operation type
  • Knowledge of the "default" map lives in the operation: changing one map is easy, but knowing the shape of all the others externally may not be
  • Still need to verify: some maps may be "valid" but do not represent the same operation anymore

Default Element

Null element or any marker that means "use the default map for this operand".

Pros:

  • Allows the operation to define its default maps: No need to teach external transforms to do that correctly

Cons:

  • Syntax is confusing: Null element isn't quite the expected semantics for a default case, anything else adds complexity
  • Parsing becomes harder: needs to interlace default arguments with custom ones and validate all that
  • Position on the map list may be unclear: easy to miss if the element is first, second, etc.

Alternative: Create a new list syntax where commas mean "empty argument" (ex. indexing_maps = [, affine_map<>, ] has three elements).

Named Maps

One affine map per operand in a separate named attribute.

Pros:

  • Terse syntax: especially when only one map is changed, but even when all are changed
  • No need to know what the default map is for all operands in all arguments

Cons:

  • Naming the attribute leads to bike-shedding: map_a? map_0?
  • Parsing changes: parsing separate arguments then munching into a list, mixing default/custom behaviour
  • Printing changes: needs to decompose indexing map into separate attributes (naming becomes a problem here)