Linalg Named Ops
Linalg generics can represent a vast number of linear algebra operations, from element-wise ops to matmuls, reductions, broadcasts, etc.
However, understanding the properties of a generic op can be daunting for pattern matchers, as they must look at the affine maps, iterators, arguments, and region scalar operations as a chain, not to mention the duality of the `outs` parameter (shape or init) and its role in bufferization.
If all operations are represented as a sequence of generic ops, even simple patterns like add + relu become complex matches on the internal properties of each op. Furthermore, generics can only represent perfectly-nested loops (the inner scalar region), so more complicated patterns (like multi-level batch-reduce GEMMs) become hard or even impossible to express.
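For illustration, here is a minimal sketch (with made-up shapes) of what a matcher has to inspect just to recognize a plain element-wise add written as a generic, before it can even start looking for the relu that follows: the identity indexing maps, the all-parallel iterators, and the arith.addf inside the region.
#map = affine_map<(d0, d1) -> (d0, d1)>
%sum = linalg.generic
         { indexing_maps = [#map, #map, #map],
           iterator_types = ["parallel", "parallel"] }
         ins(%lhs, %rhs : tensor<8x16xf32>, tensor<8x16xf32>)
         outs(%init : tensor<8x16xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    // The "name" of the operation only exists here, as the region payload
    %0 = arith.addf %a, %b : f32
    linalg.yield %0 : f32
} -> tensor<8x16xf32>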
This led to a number of named operations being introduced, for example `broadcast`, `reduce`, `matmul`, `batch_reduce_matmul`, etc., and the more recent additions of element-wise arithmetic and math operations like `add` and `exp`, as well as the ML-specific `softmax`.
Linalg operations can represent computation at various levels, from high-dimension (ex. `tensor<64x32x8x128xf32>`) at ingress, to tile sizes (ex. `tensor<32x32xf32>`) after tiling.
Generic operations need to change their affine maps and iterators to adapt to the new tile semantics, but ultimately, they're the same operations as before, only tiled/blocked by outer parallel loops.
We need to be able to represent the high-level concepts from ingress dialects (ex. StableHLO) directly as named ops in Linalg at high dimensions. These representations can be ambiguous, for example in high-dimensional `matmul` cases (batched, reducing): it's not always clear which dimensions are parallel and which are reductions on named ops.
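As a sketch of that ambiguity (with made-up shapes), the same pair of high-dimensional inputs can be read with the leading dimension as a parallel batch or as a reduction; canonical named forms make the choice explicit:
// Leading dimension is a parallel (batch) dimension
%0 = linalg.batch_matmul ins(%a, %b : tensor<8x64x32xf32>, tensor<8x32x128xf32>)
                         outs(%c : tensor<8x64x128xf32>) -> tensor<8x64x128xf32>
// Leading dimension is reduced into the output
%1 = linalg.batch_reduce_matmul ins(%a, %b : tensor<8x64x32xf32>, tensor<8x32x128xf32>)
                                outs(%d : tensor<64x128xf32>) -> tensor<64x128xf32>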
In order to pack/unpack, tile or traverse inputs in non-unit stride fashion, these operations need some way to convey the information to the transformations and lowering passes. This can either be in a canonical form (where the dims are hard-coded in the op spec) or attributes (where the compiler needs to be told how to unpack the representation).
After tiling and fusing, these named ops may end up with parallel loops of the same named linalg operations on tiles (for example, element-wise), or a variation of the same ops and other ops that compose to the same semantics as the original high-dim ones (for example softmax).
These tile-sized named operations can be more efficiently vectorized (ex. rely on "oracle knowledge" based on the name of the op), or called into hand-written micro-kernels, or fused into larger kernels for offloading, without having to pattern match large swaths of generics and line up all of their affine maps and iterators.
Many MLIR users have that infrastructure downstream, using their own dialects, which means that not only do the operations not share a design, but the passes also cannot run on generic MLIR, which segments the space and duplicates effort.
We are not, however, trying to replace existing dialects, upstream and downstream, with Linalg. We are trying to create a self-consistent linear algebra representation of common patterns (maths, ML, HPC) for efficient transformation and lowering into different low-level implementations which can be composed with each other.
By implementing this generic functionality in Linalg, we aim to share common optimization passes for CPU, GPU and accelerators, making different decisions based on cost models and heuristics, and allowing downstream dialects (ingress and egress) to complete the picture, or even co-exist alongside Linalg.
To extract performance from hardware implementations, linear algebra matrix code needs to be tiled, and its operations fused at the tile level, to keep as many of the operands as possible in registers/cache.
This can be done today with `linalg.generic` with some constraints:
- The affine maps and iterator types need to be identical
- The operations' region code needs to be at the same nest level
- The second generic loop needs to completely consume the first
This works well for element-wise operations, but as soon as the fusion needs to cross to a different nest level (ex. matmul accumulation + bias add), `linalg.generic` can't represent that anymore, and fusion cannot happen inside Linalg.
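A minimal sketch of the mismatch (with made-up shapes, and assuming the bias has already been broadcast to the full shape as %bcast_bias): the matmul generic iterates over a reduction dimension that the element-wise add doesn't have, so their regions live at different nest levels and the simple fusion rules above don't apply.
#mm_a = affine_map<(d0, d1, d2) -> (d0, d2)>
#mm_b = affine_map<(d0, d1, d2) -> (d2, d1)>
#mm_c = affine_map<(d0, d1, d2) -> (d0, d1)>
%mm = linalg.generic
        { indexing_maps = [#mm_a, #mm_b, #mm_c],
          iterator_types = ["parallel", "parallel", "reduction"] }
        ins(%a, %b : tensor<128x256xf32>, tensor<256x256xf32>)
        outs(%acc : tensor<128x256xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %0 = arith.mulf %x, %y : f32
    %1 = arith.addf %out, %0 : f32
    linalg.yield %1 : f32
} -> tensor<128x256xf32>

#ew = affine_map<(d0, d1) -> (d0, d1)>
%biased = linalg.generic
            { indexing_maps = [#ew, #ew, #ew],
              iterator_types = ["parallel", "parallel"] }
            ins(%mm, %bcast_bias : tensor<128x256xf32>, tensor<128x256xf32>)
            outs(%init : tensor<128x256xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %0 = arith.addf %x, %y : f32
    linalg.yield %0 : f32
} -> tensor<128x256xf32>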
To work around that limitation, we can use `scf.forall` for the parallel (batch) loops of the matmul with a `linalg.generic` inside on a tile; if you tile it right, you can represent a tiled reduction matmul that "looks like" a tile operation at the same nest level as the element-wise ops that follow (ex. bias add, relu), at which point fusion can happen.
However, a long series of generics forces the compiler to look at all the affine maps, iterator types and use-def chains inside the regions, across multiple `scf.forall` loops, in order to fuse. Unrecognized ops in between (ex. casts) can make it a lot harder.
So, not only do you have to hope that the conversion to a tiled reduction matmul will lead to fusion, you also have to hope that there won't be anything in between (from user code or compiler passes) that destroys the patterns for fusion.
Given that `linalg.generic` can represent a vast number of patterns, not all of them compatible with each other (strides, accumulation order, arbitrary affine maps), it becomes really hard to even know whether a non-standard affine map is compatible, at some particular nest level, with another tile-level element-wise map.
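For example, here is a sketch (made-up shapes) of a generic that produces its result transposed; deciding whether it can be fused with a plain element-wise consumer requires composing its non-identity map with the consumer's identity maps, rather than just looking at an op name.
#tr = affine_map<(d0, d1) -> (d1, d0)>
#id = affine_map<(d0, d1) -> (d0, d1)>
%t = linalg.generic
       { indexing_maps = [#tr, #id],
         iterator_types = ["parallel", "parallel"] }
       ins(%x : tensor<64x32xf32>)
       outs(%init : tensor<32x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<32x64xf32>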
By creating a number of named operations that have strict semantics (but can still be converted to/from `linalg.generic`), we simplify the analysis to only the most common cases.
Moreover, high-level operations have more homogeneous semantics, so lowering them to smaller named ops with equivalent strict semantics follows MLIR's base mantra of gradual lowering, keeping the semantics as high-level as possible for as long as possible.
They can still co-exist with and be lowered to `linalg.generic`, so there is no loss of functionality for the remaining compiler passes.
Unlike other dialects aimed at direct mapping to ML/HPC frameworks, Linalg named operations are focused on compiler transformations. Instead of mapping the frameworks' semantics one-to-one, we can create a number of operations that represent those semantics and compose more broadly.
This avoids the need to create multiple variants of the same operation for different frameworks, and allows the compiler to define canonical representations for its pattern matchers.
While Linalg can operate on sparse tensors (through the `sparse_tensor` dialect as the compute payload of a `linalg.generic`), it operates on dense tensors otherwise.
It is therefore natural to assume that all named ops also operate on dense tensors by default. Since there's no way to add sparse logic to named ops (no inner region), sparse support would need a new round of design (lambda operators?), which is not part of this proposal.
Sparse tensor support can continue being a `linalg.generic` feature until the time we have that new design.
In the same way as above, because named ops don't have affine maps, they must also assume all operations have unit stride on all dimensions, unless explicitly annotated otherwise.
For example, element-wise and matmul-like operations don't have any stride/dilation attributes while the convolution-like operations do. Therefore, all named operations that don't have such attributes are assumed to have unit strides and no dilations.
Tensors are intentionally abstract and pose no problems for this property, but memrefs can have layout annotations (affine maps and stride information).
Like the sparse representation, tiling a strided memref is non-trivial and should be left to generic representations, so named Linalg operations should not accept non-unit-strided memrefs.
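As a sketch of what this means in practice (made-up shapes), a named op would accept identity-layout memrefs but not an explicitly strided one:
// Unit strides (identity layout): accepted
linalg.add ins(%a, %b : memref<16x64xf32>, memref<16x64xf32>)
           outs(%c : memref<16x64xf32>)
// Non-unit strides in the layout: left to linalg.generic under this proposal
// %s : memref<16x64xf32, strided<[128, 2], offset: ?>>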
Instead of trying to match broadcast semantics from particular frameworks, named operations should have no implicit broadcasts: every broadcast is explicit.
This means element-wise unary and binary operations have the same input and output type, rank, shape and element type.
A broadcasting `add` on the second operand is composed of two operations:
func.func @bcast_add(%arg0 : tensor<8x16xf32>, %arg1 : tensor<16xf32>) -> tensor<8x16xf32> {
  %0 = tensor.empty() : tensor<8x16xf32>
  %bcast = linalg.broadcast ins(%arg1) outs(%0) : tensor<8x16xf32>
  %1 = tensor.empty() : tensor<8x16xf32>
  %sum = linalg.add ins(%arg0, %bcast) outs(%1) : tensor<8x16xf32>
  return %sum : tensor<8x16xf32>
}
This code can be lowered to a single `linalg.generic` that uses a broadcasting affine map (`(d0, d1) -> (d0)` or `(d0, d1) -> (d1)`, depending on which dimension you want to broadcast).
This directly annotates which operand is being broadcast and how.
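A sketch of that lowering for the bcast_add function above, broadcasting %arg1 along the first dimension:
#id    = affine_map<(d0, d1) -> (d0, d1)>
#bcast = affine_map<(d0, d1) -> (d1)>
%0 = tensor.empty() : tensor<8x16xf32>
%sum = linalg.generic
         { indexing_maps = [#id, #bcast, #id],
           iterator_types = ["parallel", "parallel"] }
         ins(%arg0, %arg1 : tensor<8x16xf32>, tensor<16xf32>)
         outs(%0 : tensor<8x16xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %1 = arith.addf %a, %b : f32
    linalg.yield %1 : f32
} -> tensor<8x16xf32>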
Many of Linalg's named ops have implicit type casting on the inputs to the output type before performing the operation. This fixes numeric precision to specific patterns and doesn't allow representation of other patterns.
We can avoid this problem by having a `linalg.typecast` operation (like `broadcast` above) and prohibiting named operations from having different input and output types.
For example:
// Linalg's current auto-casting: `%arg0` is downcast to the output type before the `add`.
// Similar to `downcast_add` below.
func.func @add(%arg0 : tensor<8x16xf32>,
               %arg1 : tensor<8x16xf16>)
    -> tensor<8x16xf16> {
  %0 = tensor.empty() : tensor<8x16xf16>
  %sum = linalg.add ins(%arg0, %arg1) outs(%0) : tensor<8x16xf16>
  return %sum : tensor<8x16xf16>
}
// Downcast %arg0, accumulate in f16, return.
// This has more rounding error than the following one.
func.func @downcast_add(%arg0 : tensor<8x16xf32>,
                        %arg1 : tensor<8x16xf16>)
    -> tensor<8x16xf16> {
  %0 = tensor.empty() : tensor<8x16xf16>
  // DOWNCAST
  %downcast = linalg.typecast ins(%arg0) outs(%0) : tensor<8x16xf16>
  %1 = tensor.empty() : tensor<8x16xf16>
  // ACCUMULATE (FP16)
  %sum = linalg.add ins(%downcast, %arg1) outs(%1) : tensor<8x16xf16>
  return %sum : tensor<8x16xf16>
}
// Upcast %arg1, accumulate in f32, downcast the output.
// This has less rounding error than the previous one,
// but requires scratch storage.
func.func @upcast_add(%arg0 : tensor<8x16xf32>,
                      %arg1 : tensor<8x16xf16>)
    -> tensor<8x16xf16> {
  %0 = tensor.empty() : tensor<8x16xf32>
  // UPCAST
  %upcast = linalg.typecast ins(%arg1) outs(%0) : tensor<8x16xf32>
  %1 = tensor.empty() : tensor<8x16xf32>
  // ACCUMULATE (FP32)
  %sum = linalg.add ins(%arg0, %upcast) outs(%1) : tensor<8x16xf32>
  %2 = tensor.empty() : tensor<8x16xf16>
  // DOWNCAST
  %downcast = linalg.typecast ins(%sum) outs(%2) : tensor<8x16xf16>
  return %downcast : tensor<8x16xf16>
}
Both functions can be lowered to a single `linalg.generic` with a combination of `arith.extf` and/or `arith.truncf`.
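For instance, a sketch of what upcast_add could collapse into: the f16 operand is extended, the sum is computed in f32, and the result is truncated back to the f16 output element type.
#map = affine_map<(d0, d1) -> (d0, d1)>
%0 = tensor.empty() : tensor<8x16xf16>
%sum = linalg.generic
         { indexing_maps = [#map, #map, #map],
           iterator_types = ["parallel", "parallel"] }
         ins(%arg0, %arg1 : tensor<8x16xf32>, tensor<8x16xf16>)
         outs(%0 : tensor<8x16xf16>) {
  ^bb0(%a: f32, %b: f16, %out: f16):
    %ext   = arith.extf %b : f16 to f32
    %acc   = arith.addf %a, %ext : f32
    %trunc = arith.truncf %acc : f32 to f16
    linalg.yield %trunc : f16
} -> tensor<8x16xf16>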
For the initial proposal, all Linalg named operations operate on dense and unit-stride tensors and memrefs. The semantics of each operation depends on its type.
All element-wise operations have inputs and outputs of the same type (rank, dims and element type), and perform the named operation on each element with the following semantics:
#map = affine_map<(dn...) -> (dn...)>
%res = linalg.OP
{
indexing_maps = [#map, ...],
iterator_types = ["parallel", ...]
}
ins(%arg0, ... : tensor<eTy>/memref<eTy>) // N-ary op has N args
outs(%outs : tensor<eTy>/memref<eTy>) // Always one outs
{
^bb0(%in: eTy, ..., %out: eTy):
%0 = [OP] %in, ... : eTy // Scalar op (chain)
linalg.yield %0 : eTy
}
Above, `OP` can be unary/binary, arithmetic/maths, and can map directly to an existing scalar op in an existing scalar dialect (`arith`, `math`, etc.).
However, it can also be a composition of existing ops, as long as it's done element-wise. For example, a ReLU operation can be implemented as `maxf(0, x)` or `c = cmp(x, 0); sel(c, x, 0)`, which have the same element-wise semantics and can both be implemented inside a `linalg.generic`.
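For example, a sketch of ReLU as an element-wise generic using the compare + select form (the max-against-zero form is analogous):
#map = affine_map<(d0, d1) -> (d0, d1)>
%relu = linalg.generic
          { indexing_maps = [#map, #map],
            iterator_types = ["parallel", "parallel"] }
          ins(%x : tensor<16x64xf32>)
          outs(%init : tensor<16x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %zero = arith.constant 0.0 : f32
    %cmp  = arith.cmpf ogt, %in, %zero : f32
    %sel  = arith.select %cmp, %in, %zero : f32
    linalg.yield %sel : f32
} -> tensor<16x64xf32>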
With strict type and shape cast semantics, we require that element-wise input and output shapes be identical and that any cast be represented as its own op.
This requirement avoids implicit assumptions (as outlined in the core semantics), but also avoids combinatorial attributes (`bcast_arg1`, `typecast_arg2`, etc.).
With explicit casts, it becomes clear which implementation to follow and which operand/output is being cast to what. It's also trivial to match a sequence of ops (or to follow the use-def chain) to find out which ops "annotate" another.
Examples:
// Plain simple add
%2 = linalg.add ins(%0, %1 : tensor<16x64xf32>, tensor<16x64xf32>)
outs(%buff : tensor<16x64xf32>)
// Broadcast + mul
%1 = linalg.broadcast ins(%cst : tensor<64xf32>)
                      outs(%buff : tensor<16x64xf32>)
%2 = linalg.mul ins(%0, %1 : tensor<16x64xf32>, tensor<16x64xf32>)
                outs(%buff : tensor<16x64xf32>)
// Exponential + reduction (over the columns, dim 1)
%1 = linalg.exp ins(%0 : tensor<16x64xf32>)
                outs(%buff : tensor<16x64xf32>)
%2 = linalg.reduce { arith.addf }
                ins(%1 : tensor<16x64xf32>)
                outs(%red_buff : tensor<16xf32>)
                dimensions = [1]
Unlike element-wise operations, cast ops must change the shape or element type of their input.
In fact, these should be the only operations that can do so, unless a named operation's semantics requires it (ex. matmul shapes are `MK x KN -> MN`).
As described above, cast operations are not always materialized. They can be used as annotation to a following operation or fused with a preceding one by changing its return value.
Broadcast operations expand the tensor in all dimensions that are unit-wide in the input type and not in the output type. Unit-wide is assumed for the outer dimensions when the rank of the input is lower than that of the output.
For example:
// Explicit dimension casts
%bcast0 = linalg.broadcast ins(%0 : tensor<1x64xf32>)
                           outs(%buf : tensor<16x64xf32>)
                           : tensor<16x64xf32>
%bcast1 = linalg.broadcast ins(%0 : tensor<16x1xf32>)
                           outs(%buf : tensor<16x64xf32>)
                           : tensor<16x64xf32>
// Implicit dimension cast (assumed to be tensor<1x64xf32>)
%bcast2 = linalg.broadcast ins(%0 : tensor<64xf32>)
                           outs(%buf : tensor<16x64xf32>)
                           : tensor<16x64xf32>
// INVALID CAST (input type should be tensor<16x1xf32>)
%bcast3 = linalg.broadcast ins(%0 : tensor<16xf32>)
                           outs(%buf : tensor<16x64xf32>)
                           : tensor<16x64xf32>
Reduction operations contract the tensor in the specified dimensions using a combining operation (ex. add
, mul
, max
, min
).
Unlike broadcast, reduction doesn't have to assume unit-wide dimensions, as the reduction dimensions are explicit in the operation's syntax. The output will have lower rank than the input, and the collapsed dimensions must match the `dimensions` attribute.
For example:
// Reduction over the rows (dim 0) into a row vector
%1 = linalg.reduce { arith.addf }
         ins(%0 : tensor<16x64xf32>)
         outs(%row_buff : tensor<64xf32>)
         dimensions = [0]
// Reduction over the columns (dim 1) into a column vector
%2 = linalg.reduce { arith.addf }
         ins(%0 : tensor<16x64xf32>)
         outs(%col_buff : tensor<16xf32>)
         dimensions = [1]
// INVALID REDUCTION (output shape keeps the reduced dimension instead of the remaining one)
%3 = linalg.reduce { arith.addf }
         ins(%0 : tensor<16x64xf32>)
         outs(%bad_buff : tensor<64xf32>)
         dimensions = [1]
A typecast changes the element type of the tensor and nothing else.
This is identical to a `linalg.generic` with the scalar cast inside, applied element-wise.
Like broadcast and reduction, this serves as an annotation to other operations, but it can also be materialized directly as an element-wise cast if there are no surrounding operations to fuse into.
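For instance, a sketch of the generic equivalent to the int-to-fp typecast in the first example below (assuming a signed integer source, hence arith.sitofp):
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
%cast = linalg.generic
          { indexing_maps = [#map, #map],
            iterator_types = ["parallel", "parallel", "parallel"] }
          ins(%0 : tensor<128x8x64xi32>)
          outs(%buff : tensor<128x8x64xf32>) {
  ^bb0(%in: i32, %out: f32):
    %fp = arith.sitofp %in : i32 to f32
    linalg.yield %fp : f32
} -> tensor<128x8x64xf32>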
For example:
// Int-to-fp
%1 = linalg.typecast ins(%0 : tensor<128x8x64xi32>)
                     outs(%buff : tensor<128x8x64xf32>)
                     : tensor<128x8x64xf32>
// Accumulator type > storage type
%1 = linalg.typecast ins(%arg1 : tensor<128x8x64xf16>)
                     outs(%buff : tensor<128x8x64xf32>)
                     : tensor<128x8x64xf32>
%2 = linalg.mul ins(%0, %1 : tensor<128x8x64xf32>, tensor<128x8x64xf32>)
                outs(%buff : tensor<128x8x64xf32>)
                : tensor<128x8x64xf32>
%3 = linalg.typecast ins(%2 : tensor<128x8x64xf32>)
                     outs(%stor : tensor<128x8x64xf16>)
                     : tensor<128x8x64xf16>
// INVALID CAST (cannot change shape)
%1 = linalg.typecast ins(%0 : tensor<128x4x64xi32>)
                     outs(%buff : tensor<128x8x64xf32>)
                     : tensor<128x8x64xf32>
In the typecast example above, we have, in theory, more operations than if the type casts were implicit, but we also have more control over how we lower them.
For example, we can tile each of those ops separately and end up with three `scf.forall` loops, one for each op.
This looks horrible, but we can now fuse them into a single `scf.forall`, which at least guarantees the casts will be performed at the tile level, not the whole-tensor level.
But now we have an allocation of the tile on every loop iteration, and the materialization of the casts, too.
There are multiple ways to fix this problem:
- Hoist the allocation of the scratch tile memory out of the loop and reuse it as temporary accumulator storage (see the sketch after this list). This needs to take into account thread-private storage and shared memory concerns (for parallel loops).
- Replace the cast + op + cast code with a micro-kernel library call that has its own memory management and uses a different accumulator type. This can work across multiple ops in complex patterns.
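A minimal sketch of the first option (made-up shapes and trip counts): the f32 scratch tile is allocated once outside the parallel loop instead of once per iteration; in a truly parallel context each thread would still need its own private copy.
// Hoisted accumulation tile, reused across iterations
%acc = tensor.empty() : tensor<32x32xf32>
%res = scf.forall (%i, %j) in (4, 8) shared_outs(%out = %init) -> (tensor<128x256xf32>) {
  // extract the f16 input tiles, upcast into %acc, compute in f32,
  // downcast the result and insert it back into %out
  ...
}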
A layer of Multi-Layer Perceptron inference with a constant scaling bias can be seen as:
L(0) = Input
Weight[Layers] = PreTrainedConsts[ ... ]
Bias = 1.0
for (n in Layers)
L(n+1) = ReLU( GEMM(L(n), Weight) + Splat(Bias) )
A single layer can be represented in Linalg as:
func.func @mlp_layer(%input : tensor<128x256xf32>,
                     %weight : tensor<256x256xf32>,
                     %bias : tensor<256xf32>) -> tensor<128x256xf32> {
  // Broadcasted bias
  %bc_bias_buf = tensor.empty() : tensor<128x256xf32>
  %bc_bias = linalg.broadcast ins(%bias : tensor<256xf32>)
                              outs(%bc_bias_buf : tensor<128x256xf32>)
  // GEMM
  %mm_buf = tensor.empty() : tensor<128x256xf32>
  %mm = linalg.matmul ins(%input, %weight : tensor<128x256xf32>, tensor<256x256xf32>)
                      outs(%mm_buf : tensor<128x256xf32>)
  // Bias Add (element-wise)
  %add_buf = tensor.empty() : tensor<128x256xf32>
  %add = linalg.add ins(%mm, %bc_bias : tensor<128x256xf32>, tensor<128x256xf32>)
                    outs(%add_buf : tensor<128x256xf32>)
  // ReLU (element-wise max against zero)
  %relu_buf = tensor.empty() : tensor<128x256xf32>
  %relu = linalg.max ins(%add : tensor<128x256xf32>)
                     outs(%relu_buf : tensor<128x256xf32>)
  // Return activation
  return %relu : tensor<128x256xf32>
}
After tiling and fusion, we can lower this to two `scf.forall` parallel loops with tile operations inside.
The first one has the matmul, which has a reduction dimension, so it isn't at the same nest level as the element-wise ops that follow.
The second has both the element-wise `add` and `max`.
However, how the compiler does this influences what needs to be materialized and what doesn't.
For example, hoisting the bias broadcast may seem like a good idea, since it's a constant that is reused by all loops:
func.func @mlp_layer(%input : tensor<128x256xf32>,
                     %weight : tensor<256x256xf32>,
                     %bias : tensor<256xf32>) -> tensor<128x256xf32> {
  // Outer matmul parallel loops
  %mm = scf.forall (...) {
    ...
  }
  // Broadcasted bias (on the whole tensor)
  %bc_bias_buf = tensor.empty() : tensor<128x256xf32>
  %bc_bias = linalg.broadcast ins(%bias : tensor<256xf32>)
                              outs(%bc_bias_buf : tensor<128x256xf32>)
  // Outer element-wise parallel loops
  %res = scf.forall (...) {
    ...
  }
  // Return activation
  return %res : tensor<128x256xf32>
}
If the `add` operation knows that its input is a broadcast, it can change its own affine map to read from the same row for every column, avoiding allocations and memory movement.
If the broadcast is hoisted, to recover that relationship the compiler needs to make sure all users of `%bc_bias` can change how they read their inputs before it can remove the hoisted operation.
It is much easier, however, to find patterns when they're local, and then hoist what's left. So, assuming the broadcast will be interpreted as an annotation by the lowering pass, we end up with the following inner loops:
// First, the GEMM loop
%mm = scf.forall (...) {
%tile = linalg.matmul ...
}
// Now, the element-wise ops
// Assuming the "optimal" tile size is 32x32 for FP32
scf.forall (...) {
%biasT = tensor.extract_slice %bias[...] : tensor<32xf32>
%mmT = tensor.extract_slice %mm[...] : tensor<32x32xf32>
// Broadcasted bias
%bc_bias_buf = tensor.empty() : tensor<32x32xf32>
%bc_bias = linalg.broadcast ins(%biasT : tensor<32xf32>)
outs(%bc_bias_buf : tensor<32x32xf32>)
// Bias Add (element-wise)
%add_buf = tensor.empty() : tensor<32x32xf32>
%add = linalg.add ins(%mmT, %bc_bias : tensor<32x32xf32>, tensor<32x32xf32>)
outs(%add_buf : tensor<32x32xf32>)
// ReLU
%relu_buf = tensor.empty() : tensor<32x32xf32>
%relu = linalg.max ins(%add : tensor<32x32xf32>)
outs(%relu_buf : tensor<32x32xf32>)
// Parallel insert
scf.forall.in_parallel {
tensor.parallel_insert_slice %relu into ...
}
}
As is, this code is already more efficient than the original.
Both `add` and `max` are inside the same loop nest and form a chain that connects the `tensor.extract_slice` with the `tensor.parallel_insert_slice`, and they will be performed on the same tile while the data is still in registers / cache.
Because none of the named ops have affine maps or iterators, it is clear just by looking at the inputs and outputs of this chain that the operations are fused.
Because `linalg.add` has one of its operands annotated with `linalg.broadcast`, the compiler can either lower the code to a `linalg.generic` with the corresponding affine map, or call a micro-kernel function that performs the same computation.
The key difference here is that the transformation of the graph is done on traditional compiler concepts (operations, use-def chains, dense compute), while the lowering pass still has full information about the nuances needed for efficient code generation.
Now suppose your kernel library supports a fused batch-reduce matmul micro-kernel that not only performs a batch-reduce GEMM, but can also apply element-wise binary and unary ops to the result at the tile level.
In this case, the GEMM loop and the element-wise loop can fuse into one, and you can look at the inner loop as a sequence of "element-wise" calls writing to the same tile. So, while the inputs of the batch-reduce GEMM come from other places (and the implementation can use registers wisely), the accumulation is on the same tile that the remaining operations will read/write.
So, in pseudo-MLIR:
loop {
  %input_row = tensor.extract_slice %input[...]
  %input_col = tensor.extract_slice %weight[...]
  %buf = tensor.empty() : tensor<32x32xf32>
  %mm = linalg.batch_reduce_matmul(%input_row, %input_col) -> %buf ... : tensor<32x32xf32>
  %b = linalg.broadcast %bias ... : tensor<32x32xf32>
  %add = linalg.add(%mm, %b) ... : tensor<32x32xf32>
  %res = linalg.max(%add) ... : tensor<32x32xf32>
  tensor.insert_slice %res into %output ... : tensor<32x32xf32>
}
Now we have a batched GEMM + `add` + `relu` that can keep the accumulating `32x32` tile in registers for the whole inner loop.
The batch-reduce matmul will multiply `32x32` blocks from the block-row and block-column and accumulate into the tile buffer, which doesn't need to be allocated on every iteration, but once per thread and reused.
The bias add can directly read the same row of the bias vector over and over, perhaps even keep it in registers (if small enough), and just add the bias to the scratch-pad after the matmul has finished that tile.
The ReLU is just a simple unary element-wise `maxf`, which can be dispatched in the same (`32x32`) loop as the `addf`, and software pipelining can hide its cost well.
But this becomes even easier if you have a micro-kernel library that implements all of that for you at the tile level, using the best strategy for instruction sequences, registers and cache; you can then replace the entire loop body with a single call (on CPUs):
loop {
%input_row = tensor.extract_slice %input[...]
%input_col = tensor.extract_slice %weight[...]
%res = mylib.fused_brgemm(%input_row, %input_col, %bias)
{ binary_function: arith.addf, binary_flags: col_bcast }
{ unary_function: arith.maxf, unary_flags: none } : tensor<32x32xf32>
tensor.insert_slice %res into %output ... : tensor<32x32xf32>
}
Given that the cast operations annotate a specific operand (to make sure we get the right cast), this can lead to patterns not being matched when searching for combined operations.
Most binary element-wise operations are commutative, so there could be a pass that rewrites operations to place their "broadcasted" operand last, for example.
But some operations are not commutative (ex. `div`), so the pattern matcher still has to look past casts when following the use-def chain.
There could be a greedy matcher that knows the difference between actual producers (arith, maths, ML, etc.) and casts (broadcast, typecast) to match variations of the same multi-op patterns.
Reduction is more complicated because it has additional semantics (which combiner operation was used, ex. `add` or `max`), which will change the pattern.
So we could have a few generic functions like:
// Helper for embeddings matcher
Value findProducerThroughCast<linalg::TypecastOp>(Value operand);
// Helper for fully-connected bias add
Value findProducerThroughCast<linalg::BroadcastOp>(Value operand);
// Helper for softmax
Value findProducerThroughCast<linalg::ReduceOp, arith::AddFOp>(Value operand);