Linalg matmul with affine maps
We have an explosion of matmul variants in opDSL, and we want to unify them under a single syntax for all three matmul variants. Previous discussion has concluded that opDSL isn't powerful enough to represent this, so we have to move them to TableGen.
Goals:
- To add syntax to matmul without changing any of the existing syntax expectations for current usage. matmul is still just matmul.
- To expose permutation and broadcast semantics on the three matmul variations: matmul, batch_matmul and batch_reduce_matmul.
Non-Goals:
- To change how we use these matmul variations (semantics or syntax), including their transposed versions, upstream and downstream.
- To remove their transposed versions from the dialect now.
- To replace matmul with a generic contraction operation, which is far too complex to design and agree upon in the same time frame.
Because matmul is defined in opDSL and we want to move to TableGen, we must remove the definition from opDSL to be able to add it to TableGen. But this is not a change, it's just code movement. The old syntax and semantics still need to be 100% identical. The way the ops are used in C++ or IR cannot change in any way, shape or form.
The syntax and semantics of the batch and reduce variants need to be identical to the base matmul op. We should discuss what it looks like on matmul first (because it's easier and separates concerns), and then extend it to the other two.
Pushing the PRs one at a time or all three at the same time doesn't matter much, as long as we all agree on the syntax of the three variants in unison.
As discussed in the RFC, in the document we wrote, and again on the PR, we should use affine maps for both transpose and broadcast.
// The no-maps form still _must_ continue to work
// This means (m, k) x (k, n) -> (m, n)
%0 = linalg.matmul
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// Fully exposed maps can still work, even with the same (default) semantics
%0 = linalg.matmul
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
affine_map<(d0, d1, d2) -> (d2, d1)>, // k, n
affine_map<(d0, d1, d2) -> (d0, d1)> // m, n
]
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// Fully exposed maps work on transposes, but require the rest to also be explicit,
// because `indexing_maps` is a list
%0 = linalg.matmul
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
affine_map<(d0, d1, d2) -> (d0, d1)> // m, n
]
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// It also works when more than one operand has a custom map
%0 = linalg.matmul
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0)>, // m (broadcast column)
affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
affine_map<(d0, d1, d2) -> (d0, d1)> // m, n
]
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// We could try to create a _"null map"_ that means "default behaviour". Still looks ugly.
%0 = linalg.matmul
indexing_maps = [
affine_map<>, // m, k (default behaviour)
affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
affine_map<> // m, n (default behaviour)
]
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// This could be made to look "nicer" with named maps
#default = affine_map<>
#transpose_b = affine_map<(d0, d1, d2) -> (d1, d2)>
%0 = linalg.matmul
indexing_maps = [
#default, // m, k (default behaviour)
#transpose_b, // n, k (transposed B)
#default // m, n (default behaviour)
]
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// We could try naming the maps. Makes it easier to only have one "overriding" map.
%0 = linalg.matmul
map_b = affine_map<(d0, d1, d2) -> (d1, d2)> // n, k (transposed B)
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
// Are `a`, `b`, `c` really the best names here?
%0 = linalg.matmul
map_a = affine_map<(d0, d1, d2) -> (d0)>, // m (broadcast column)
map_b = affine_map<(d0, d1, d2) -> (d1, d2)> // n, k (transposed B)
ins(%a, %b: tensor<>, tensor<>)
outs (%c: tensor<>) : tensor<>
Batched matmuls have 3D inputs and outputs: the extra dimension is a "batch" dimension and the matrices are multiplied one by one. Batch-reduce matmuls do a further reduction in addition to the batch multiply, so while the inputs are still 3D, the output is a 2D "tile" accumulating all the batches.
In addition to the transpose and broadcast semantics for matmul, the batch version batch_matmul needs to agree on which dimension is the batch dimension of each operand.
Currently, the assumption is that the "outer" dimension is the batch dimension and the "transpose" variants only swap the inner dimensions.
// Standard batch matmul
%0 = linalg.batch_matmul
ins(%a, %b: tensor<BxMxK>, tensor<BxKxN>)
outs (%c: tensor<BxMxN>) : tensor<BxMxN>
// Transpose A batch matmul (note KxM on the first arg)
%0 = linalg.batch_matmul_transpose_a
ins(%a, %b: tensor<BxKxM>, tensor<BxKxN>)
outs (%c: tensor<BxMxN>) : tensor<BxMxN>
// Transpose B batch matmul (note NxK on the second arg)
%0 = linalg.batch_matmul_transpose_b
ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
outs (%c: tensor<BxMxN>) : tensor<BxMxN>
These transposed variants should be trivial to express using affine maps:
// Transpose B batch matmul
%0 = linalg.batch_matmul
indexing_maps = [
affine_map<(dB, d0, d1, d2) -> (dB, d0, d2)>, // batch, m, k
affine_map<(dB, d0, d1, d2) -> (dB, d1, d2)>, // batch, n, k (transposed B)
affine_map<(dB, d0, d1, d2) -> (dB, d0, d1)> // batch, m, n
]
ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
outs (%c: tensor<BxMxN>) : tensor<BxMxN>
But adding affine maps may open the door to "arbitrary" contraction semantics, for example making the inner dimension the "batch" dimension and the others the "matmul" dimensions. However, this would move the semantics of the batch_matmul operations towards a more generic contract operation, and that's a non-goal.
So, we propose keeping the batch_matmul semantics and checking in the verifier that custom affine maps do not violate these assumptions, with errors emitted if a user-defined affine map does not match the expectations.
Requirements:
- The dB dim is always the batch dimension; the other three are the "matmul" dimensions.
- Transpose can only occur on "matmul" dimensions.
- Broadcast can happen on both "batch" and "matmul" dimensions.
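As an illustration (a hypothetical map, not from the PR), a map that moves the batch dimension into an inner position would violate these requirements and should be rejected by the verifier:
// Invalid: the batch dimension is moved to an inner position, turning this
// into an arbitrary contraction rather than a batch matmul
#invalid_a = affine_map<(dB, d0, d1, d2) -> (d0, dB, d2)> // m, batch, k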
Valid transpose maps:
#transpose_a = affine_map<(dB, d0, d1, d2) -> (dB, d2, d0)> // batch, k, m
#transpose_b = affine_map<(dB, d0, d1, d2) -> (dB, d1, d2)> // batch, n, k
#transpose_c = affine_map<(dB, d0, d1, d2) -> (dB, d1, d0)> // batch, n, m
Valid broadcast maps:
#broadcast_col_a = affine_map<(dB, d0, d1, d2) -> (dB, d0)>   // batch, m (k is broadcast)
#broadcast_row_a = affine_map<(dB, d0, d1, d2) -> (dB, d2)>   // batch, k (m is broadcast)
#broadcast_batch_a = affine_map<(dB, d0, d1, d2) -> (d0, d2)> // m, k (batch is broadcast)
#broadcast_col_b = affine_map<(dB, d0, d1, d2) -> (dB, d2)>   // batch, k (n is broadcast)
#broadcast_row_b = affine_map<(dB, d0, d1, d2) -> (dB, d1)>   // batch, n (k is broadcast)
#broadcast_batch_b = affine_map<(dB, d0, d1, d2) -> (d2, d1)> // k, n (batch is broadcast)
#broadcast_col_c = affine_map<(dB, d0, d1, d2) -> (dB, d0)>   // batch, m (n is broadcast)
#broadcast_row_c = affine_map<(dB, d0, d1, d2) -> (dB, d1)>   // batch, n (m is broadcast)
#broadcast_batch_c = affine_map<(dB, d0, d1, d2) -> (d0, d1)> // m, n (batch is broadcast)
Since transpose can only occur on "matmul" dimensions, a transpose+broadcast map is just the complementary broadcast map:
#broadcast_col_a = affine_map<(dB, d0, d1, d2) -> (dB, d0)>           // batch, m
#broadcast_row_a = affine_map<(dB, d0, d1, d2) -> (dB, d2)>           // batch, k
#transpose_broadcast_col_a = affine_map<(dB, d0, d1, d2) -> (dB, d2)> // Same as row_a
The input type must be appropriately shaped to accept that map (so possibly extended from <BxN> to <Bx1xN>).
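For example, a batch broadcast of A could look like this (a sketch with symbolic shapes, like the examples above):
// Batch broadcast of A: the same 2D matrix is reused for every batch
%0 = linalg.batch_matmul
indexing_maps = [
affine_map<(dB, d0, d1, d2) -> (d0, d2)>,     // m, k (batch broadcast A)
affine_map<(dB, d0, d1, d2) -> (dB, d2, d1)>, // batch, k, n
affine_map<(dB, d0, d1, d2) -> (dB, d0, d1)>  // batch, m, n
]
ins(%a, %b: tensor<MxK>, tensor<BxKxN>)
outs (%c: tensor<BxMxN>) : tensor<BxMxN>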
The batch-reduce version is similar to the batch matmul, but with an additional reduction over the batch dimension. The same argument for restricting the affine maps to "known configurations", to avoid general contraction semantics, applies.
// Standard batch reduce matmul
%0 = linalg.batch_reduce_matmul
ins(%a, %b: tensor<BxMxK>, tensor<BxKxN>)
outs (%c: tensor<MxN>) : tensor<MxN>
// Transpose A batch reduce matmul (note KxM on the first arg)
%0 = linalg.batch_reduce_matmul_transpose_a // Note: this op does not exist, but is extrapolated from the rest
ins(%a, %b: tensor<BxKxM>, tensor<BxKxN>)
outs (%c: tensor<MxN>) : tensor<MxN>
// Transpose B batch reduce matmul (note NxK on the second arg)
%0 = linalg.batch_reduce_matmul_transpose_b // Note: this op does not exist, but is extrapolated from the rest
ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
outs (%c: tensor<MxN>) : tensor<MxN>
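With the proposed syntax, these variants would instead use explicit maps; for example the transposed-B version (a sketch, reusing the batch map convention above):
// Transpose B batch reduce matmul with affine maps
%0 = linalg.batch_reduce_matmul
indexing_maps = [
affine_map<(dB, d0, d1, d2) -> (dB, d0, d2)>, // batch, m, k
affine_map<(dB, d0, d1, d2) -> (dB, d1, d2)>, // batch, n, k (transposed B)
affine_map<(dB, d0, d1, d2) -> (d0, d1)>      // m, n (batch reduced away)
]
ins(%a, %b: tensor<BxMxK>, tensor<BxNxK>)
outs (%c: tensor<MxN>) : tensor<MxN>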
The same restrictions as for batch_matmul apply, and in addition:
- There is no "batch broadcast" for the C argument, since it's a 2D shape.
With these changes applied, there will be three ways to represent a broadcast/transpose matmul in linalg:
- linalg.generic with explicit maps (generic form)
- linalg.matmul with explicit maps (explicit form)
- A sequence of linalg.broadcast, linalg.transpose, and linalg.matmul with implicit maps (DAG form)
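As an illustration (a sketch with symbolic shapes; %bt_init is an assumed empty init tensor), the same transposed-B matmul in each of the three forms:
// Generic form: linalg.generic with explicit maps and iterators
%0 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%a, %b: tensor<MxK>, tensor<NxK>)
outs(%c: tensor<MxN>) {
^bb0(%in_a: f32, %in_b: f32, %out: f32):
%1 = arith.mulf %in_a, %in_b : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
} -> tensor<MxN>
// Explicit form: linalg.matmul with explicit maps (as proposed here)
%0 = linalg.matmul
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>, // m, k
affine_map<(d0, d1, d2) -> (d1, d2)>, // n, k (transposed B)
affine_map<(d0, d1, d2) -> (d0, d1)>  // m, n
]
ins(%a, %b: tensor<MxK>, tensor<NxK>)
outs (%c: tensor<MxN>) : tensor<MxN>
// DAG form: an explicit transpose feeding a plain matmul
%bt = linalg.transpose ins(%b: tensor<NxK>) outs(%bt_init: tensor<KxN>) permutation = [1, 0]
%0 = linalg.matmul
ins(%a, %bt: tensor<MxK>, tensor<KxN>)
outs (%c: tensor<MxN>) : tensor<MxN>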
We already have passes that "generalize" linalg operations from single named operations into their generic forms, if the named ops represent a perfectly nested operation. These should continue to work with user-defined affine maps, except that the generic form's maps will be the user-defined ones. This is trivial since, whatever the syntax, the internal representation of the maps/attributes should be stored as a list of affine maps on construction.
However, what is missing is the generalization from a DAG form (ex. matmul(a, transpose(b), c)). This requires more complex matchers, which are doable but not done yet. Even if operands are themselves produced by transpose/broadcast operations, one can still propagate that through, passing the original (un-transposed/un-broadcast) values and updating the affine maps.
There is current work to de-generalize linalg forms into named operations. These currently work on a small subset of operations and do not construct a DAG or use named operations with indexing maps (because the latter don't exist yet). Again, it should be trivial to generate the latter for the alternatives that are supported (it may need a long if/else chain, though).
Raising generics to a sequence of named operations, however, would need a matching structure that separates the semantics of affine maps and iterator types, to be able to use multiple named operations in a chain.
There is no such transform in MLIR today, but it would not be much different from the proposals above for converting between generic and named forms.
Going from the DAG form to the explicit form is the most trivial: the transform knows which operation maps to which affine map and just updates it on the payload operation. The complexity comes when the DAG is not a chain (ex. users outside of the chain), or when more than one shape-altering operation produces an operand (ex. matmul(transpose(broadcast(a)), transpose(b), broadcast(c))).
Going from the explicit form to the DAG form is simple for the common cases, where we know what the affine map represents (given the source operation, ex. element-wise or matmul) and can map it to a particular external operation. But this would need a precise mapping of the semantics of the external operations (encoded in a language reference, not just defined by their implementation).
We should reuse as much as possible the C++ code that calculates and verifies the affine maps. If linalg.broadcast uses some functions to validate its maps against some shapes, we need to use the same functions, which may have to be extracted into a common utility. The same goes for transpose.
In effect, calculating a chain of shape transformations should simply be the following sequence (as sketched below):
- Ask the operation what the default map is for that particular operand
- Use the aforementioned functions to apply the affine transformations to the map
- Persist that as the final version of that particular map
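A worked example of that sequence for a transpose of A on a plain matmul (a sketch):
// 1. Default map for A, as defined by linalg.matmul
#a_default = affine_map<(d0, d1, d2) -> (d0, d2)>    // m, k
// 2. Apply the transpose of A to the map (permute its results)
// 3. Persist the result as A's final indexing map
#a_transposed = affine_map<(d0, d1, d2) -> (d2, d0)> // k, m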
If we end up with a combination of explicit and DAG forms (i.e. a transpose on top of a transposed matmul), we need to be able to reason about the semantics. The easiest way is to simply apply the same functions to the user-defined map and notice when it is back in its default configuration.
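For example (a sketch), folding a linalg.transpose of B into a matmul that already carries a transposed-B map yields the default map again:
// Transposed-B map already on the matmul
#b_custom = affine_map<(d0, d1, d2) -> (d1, d2)>  // n, k
// After folding a further transpose of B: back to the default configuration
#b_default = affine_map<(d0, d1, d2) -> (d2, d1)> // k, n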
A matmul printer that notices the affine maps are in their "default configuration" can omit the syntax, thus converting it back to a simple matmul without explicit maps.
This is a very straightforward form of canonicalization by construction that does not need a special pass: just iterate until all operations are in their most compact representation. A configuration parameter could list operations that should not be folded (ex. linalg.fill), finding a fixed point where they still exist while all the others have been converted into affine maps on named operations.
This would solve the "softmax problem" exposed in https://github.com/llvm/llvm-project/pull/97582.
There are three variants above that could work; we need to pick one and stick with it.
- Indexing maps are either fully off (default behaviour) or fully on (all maps explicit, even if all but one are unchanged).
- Indexing maps are still a list, but we have a "null element" that represents the default behaviour (purely syntactic sugar).
- Indexing maps are named and only those that are different from the default behaviour need to be explicit.
Any of the three solutions must still abide by the Goals and Non-Goals above.
Option 1 is to just add an optional indexing_maps attribute to the operation.
Pros:
- Easiest to implement: implicit still keeps the automatic maps, explicit replaces the whole list, and it just needs validation
- Looks like generics: syntax-wise, it's just like generics and we can use well named out-of-line maps to make it easier to read
- Best case for multiple custom maps: when most or all inputs have custom maps
Cons:
- Verbose when only one map is different: the whole list of maps must be spelled out for all operands, for every operation type
- Knowledge of the "default" map is in the operation: changing one map is easy; knowing the shape of all the others from outside may not be
- Still need to verify: some maps may be "valid" but do not represent the same operation anymore
Option 2 is a null element (or any other marker) that means "use the default map for this operand".
Pros:
- Allows the operation to define its default maps: No need to teach external transforms to do that correctly
Cons:
- Syntax is confusing: Null element isn't quite the expected semantics for a default case, anything else adds complexity
- Parsing becomes harder: needs to interlace default arguments with custom ones and validate all that
- Position on the map list may be unclear: easy to miss if the element is first, second, etc.
Alternative: create a new list syntax where commas mean "empty argument" (ex. indexing_maps = [, affine_map<>, ] has three elements).
Option 3 is one affine map per operand, each in a separate named attribute.
Pros:
- Terse syntax: especially when only one map is changed, but even when all are changed
- No need to know what the default map is for all the other operands
Cons:
- Naming the attributes leads to bike-shedding: map_a? map_0?
- Parsing changes: parse separate attributes, then merge them into a list, mixing default and custom behaviour
- Printing changes: needs to decompose the indexing maps into separate attributes (naming becomes a problem here)