Tensor of Tensors
In order to allow for nested `linalg.generic` operations, we need tensors of tensors, or `tensor<N x M x tensor<I x J x f32>>`, with `N` and `M` the outer dimensions and `I` and `J` the inner dimensions.
This can be the same as `tensor<N x M x I x J x f32>`, but it doesn't always have to be. Given that tensors don't have representational guarantees (i.e. contiguous in memory), there's no requirement for the element types to have any either (e.g. via `encoding`). For example, you could have an outer sparse tensor of inner dense tensors, or vice-versa. However, if all dimensions are dense, then the two forms are equivalent. This will be important when we map to `memref` of `memref`s.
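As a purely hypothetical illustration (this syntax does not parse today; nesting a tensor element type, and reusing the existing `#sparse_tensor.encoding` on the outer tensor, are assumptions of this proposal), an outer sparse tensor of inner dense tiles could be written as:

```mlir
// Hypothetical: a CSR-encoded outer tensor whose stored elements are dense 32x32 tiles.
#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
!sparse_of_dense = tensor<4x8 x tensor<32x32xf32>, #CSR>

// The all-dense nesting would be layout-equivalent to the flattened 4D type.
!dense_of_dense = tensor<4x8 x tensor<32x32xf32>>   // ~ tensor<4x8x32x32xf32>
```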
This example is motivated by the nested linalg example. Basically, we want to allow `linalg.generic`s inside `linalg.generic`s to help with non-perfectly nested fusion (e.g. GEMM + element-wise).
This allows us to tile a `linalg.generic` and still hold an outer `linalg.generic`, which can be tiled and fused on its own, recursively. We can represent a lot of very complex patterns this way, while keeping the tiling and fusing logic simple, since it doesn't need to worry about the order of the outer dimensions or the dimensionality of the inner dimensions.
Furthermore, it allows us to distribute tiles, lanes and threads using nesting of layout `encoding`s. For example, one could use an outer layout to distribute the sub-tensors across different nodes, a middle layout to distribute across NUMA regions, and the inner (dense) layout to optimize for cache and register usage, per thread.
But to start the work, we need a simpler example: for instance, a perfectly nested linalg generic that is tiled into a nested `linalg.generic`.
Original:

```mlir
// GEMM maps
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// 2D tensors
func.func @entry(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32> {
  // Constant weight
  %cst = arith.constant dense<1.000000e+00> : tensor<1024x1024xf32>
  // Zero accumulation
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<256x1024xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
  // GEMM on scalar elements
  %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<256x1024xf32>, tensor<1024x1024xf32>) outs(%1 : tensor<256x1024xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %3 = arith.mulf %in, %in_1 : f32
    %4 = arith.addf %out, %3 : f32
    linalg.yield %4 : f32
  } -> tensor<256x1024xf32>
  return %2 : tensor<256x1024xf32>
}
```
Packed:

```mlir
// GEMM maps for 4D packed tensors
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
// 4D tensors
func.func @entry(%arg0: tensor<4x16x64x64xf32>) -> tensor<4x16x64x64xf32> {
  // Constant 4D weight
  %cst = arith.constant dense<1.000000e+00> : tensor<16x16x64x64xf32>
  // Zero 4D accumulation
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x16x64x64xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4x16x64x64xf32>) -> tensor<4x16x64x64xf32>
  // Perfectly nested 4D GEMM (ready for tiling)
  %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<4x16x64x64xf32>, tensor<16x16x64x64xf32>) outs(%1 : tensor<4x16x64x64xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %3 = arith.mulf %in, %in_1 : f32
    %4 = arith.addf %out, %3 : f32
    linalg.yield %4 : f32
  } -> tensor<4x16x64x64xf32>
  return %2 : tensor<4x16x64x64xf32>
}
```
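For reference (this is a sketch, not part of the original example), the packed input above can be produced from the 2D input with `tensor.pack` (`linalg.pack` in newer MLIR); the weight needs a similar packing with an outer-dimension permutation, omitted here:

```mlir
// Pack the 256x1024 input into 64x64 tiles: [256/64][1024/64][64][64] = 4x16x64x64.
%dest = tensor.empty() : tensor<4x16x64x64xf32>
%packed = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [64, 64]
    into %dest : tensor<256x1024xf32> -> tensor<4x16x64x64xf32>
```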
Tensors of tensors:

```mlir
// GEMM maps (note, exactly the same as 2D)
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// 2D tensors (of tensors)
func.func @entry(%arg0: tensor<4x8x tensor<32x32xf32>>) -> tensor<4x8x tensor<32x32xf32>> {
  // Constant weight "2D"
  %cst = arith.constant dense<1.000000e+00> : tensor<8x8x tensor<32x32xf32>>
  // Zero accumulation "2D"
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x8x tensor<32x32xf32>>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4x8x tensor<32x32xf32>>) -> tensor<4x8x tensor<32x32xf32>>
  // GEMM on tensor elements (still "2D")
  %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<4x8x tensor<32x32xf32>>, tensor<8x8x tensor<32x32xf32>>) outs(%1 : tensor<4x8x tensor<32x32xf32>>) {
  ^bb0(%in: tensor<32x32xf32>, %in_1: tensor<32x32xf32>, %out: tensor<32x32xf32>):
    // Tile GEMM 2D (equivalent to linalg.generic, but let's use matmul for simplicity)
    %3 = linalg.matmul ins(%in, %in_1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
    linalg.yield %3 : tensor<32x32xf32>
  } -> tensor<4x8x tensor<32x32xf32>>
  return %2 : tensor<4x8x tensor<32x32xf32>>
}
```
The layout expectations are exactly the same as before: there are none.
With `encoding` specifications (or future distributed / sparse attributes), one can change the expectations of a tensor.
However, on a tensor of scalars, the encoding is all-or-nothing, i.e. the entire tensor must adopt the same encoding.
This leads to compilers splitting the tensor into multiple smaller tensors, while only the compiler knows they're actually sub-tensors of the original.
Once the IR is produced, a following compiler has lost those semantics.
With tensor of tensors, however, we can not only keep that information in the IR, but also create compiler passes that use it for fusion. For example, two unrelated computations can only be fused after tiling if their iteration semantics are identical. If tiling produces loops, the iteration semantics can only be recovered from induction-variable scalar-evolution analysis, which may not be complete.
With `linalg.generic` applying regions to sub-tensors, the traversal semantics are encoded in affine maps and tensor attributes, and can be more easily manipulated even if a generic pass doesn't understand the underlying encoding.
If the semantics are strongly held at construction and verification time for each encoding, then a generic pass can assume that two identical encodings iterate the same way and can therefore fuse the operations.
Furthermore, having those representations explicit could allow the compiler to reorder sub-tensors to match following operations, mark operations for fusion at different nesting levels, etc.
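As a hedged sketch of that argument (still using the proposed, not-yet-existing element type), consider an element-wise ReLU consuming the GEMM result from the tensors-of-tensors example above (`%2` and `%0` refer to the values defined there). Its outer indexing maps match the GEMM's parallel dimensions, so the two outer generics could be fused tile by tile without inspecting the inner sub-tensors:

```mlir
#id2d = affine_map<(d0, d1) -> (d0, d1)>

// Element-wise ReLU over the same 4x8 grid of 32x32 tiles produced by the GEMM.
%relu = linalg.generic {indexing_maps = [#id2d, #id2d], iterator_types = ["parallel", "parallel"]}
    ins(%2 : tensor<4x8x tensor<32x32xf32>>) outs(%0 : tensor<4x8x tensor<32x32xf32>>) {
  ^bb0(%in: tensor<32x32xf32>, %out: tensor<32x32xf32>):
    %zero = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
    // arith ops already accept tensor operands element-wise, so max(x, 0) applies to a whole tile.
    %m = arith.maximumf %in, %zero : tensor<32x32xf32>
    linalg.yield %m : tensor<32x32xf32>
  } -> tensor<4x8x tensor<32x32xf32>>
```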
The big question here is: what does it mean to have an element type of `tensor<32x32xf32>`?
In theory, tensors do NOT have a pre-defined layout, so representing a tile as a 32-by-32 tensor is perfectly fine. In practice, tiles are not contiguous in memory, and bufferization may assume they are if it doesn't know that this is an element type.
This document assumes that reading and writing an element of tensor type is equivalent to `tensor.extract_slice` and `tensor.insert_slice`.
While bufferization converts these into `memref.subview` with stride information, tensor has no such notion.
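Under that assumption, accessing element `(i, j)` of a `tensor<4x8 x tensor<32x32xf32>>` behaves like the following slice operations on the flattened dense type (a sketch; the flattened form is only equivalent when everything is dense):

```mlir
// Read "element" (i, j): a rank-reduced slice of the flattened 4D tensor.
%tile = tensor.extract_slice %flat[%i, %j, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
    : tensor<4x8x32x32xf32> to tensor<32x32xf32>

// ... compute on the 32x32 tile ...

// Write "element" (i, j) back.
%updated = tensor.insert_slice %tile into %flat[%i, %j, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
    : tensor<32x32xf32> into tensor<4x8x32x32xf32>
```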
The following example is not invalid, but it is undefined:

```mlir
// Illustrative callee operating on a 32x32 tile.
func.func private @some_func(tensor<32x32xf32>) -> tensor<32x32xf32>

func.func @some_random_func(%arg0 : tensor<4x8x32x32xf32>) { // Let's assume this is contiguous
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  // This can also be assumed to be a dense contiguous 32x32 tensor
  %buf = tensor.empty() : tensor<32x32xf32>
  // This is most certainly not contiguous in memory
  %slice = tensor.extract_slice %arg0[%c1, %c2, 0, 0][1, 1, 32, 32][1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32>
  // At the tensor level, how does `@some_func` know the layout of the tensor?
  %0 = call @some_func(%buf) : (tensor<32x32xf32>) -> tensor<32x32xf32>   // Dense
  %1 = call @some_func(%slice) : (tensor<32x32xf32>) -> tensor<32x32xf32> // Strided
  return
}
```
When mapping tensor of tensors to memrefs, there are two options:
- Implement `memref` of `memref`s
- Full encoding-aware lowering

The first case is just a delayed version of the second, since we'll have to fully lower eventually. But it allows the compiler to continue running fusion passes while already at the memref level. However, this also means memref operations need to adapt to understand the semantics of a non-scalar element type.
Since we'll have to implement full lowering at some point for this work to be complete, it would be quicker to adopt the second strategy first and only implement `memref` of `memref`s if and when it becomes necessary.
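For the all-dense case, a minimal sketch of what the fully lowered form might look like (an illustration, not a lowering mandated by this proposal): the outer dimensions become explicit loops and the inner "elements" become rank-reduced subviews of the flattened buffers.

```mlir
// Lowered version of the tensors-of-tensors GEMM above, assuming %C was zero-filled.
func.func @lowered_gemm(%A: memref<4x8x32x32xf32>, %B: memref<8x8x32x32xf32>,
                        %C: memref<4x8x32x32xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  scf.for %i = %c0 to %c4 step %c1 {
    scf.for %j = %c0 to %c8 step %c1 {
      scf.for %k = %c0 to %c8 step %c1 {
        // Each inner "element" is a strided, rank-reduced subview of the outer buffer.
        %a = memref.subview %A[%i, %k, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
            : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        %b = memref.subview %B[%k, %j, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
            : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        %c = memref.subview %C[%i, %j, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
            : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        // Tile GEMM, accumulating into the output subview.
        linalg.matmul ins(%a, %b : memref<32x32xf32, strided<[32, 1], offset: ?>>,
                                   memref<32x32xf32, strided<[32, 1], offset: ?>>)
                      outs(%c : memref<32x32xf32, strided<[32, 1], offset: ?>>)
      }
    }
  }
  return
}
```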
Tensors of vectors can already be used, but there are two main hindrances:
- Vectors have restricted semantics (dense, cache/register location, register semantics)
- They can only have scalar element types

While the second restriction isn't quite a problem, the first one is, as it would not allow us to have sparse or distributed semantics, or recursive nesting.
However, tensors of vectors have a clear advantage: scalar operations (`arith`/`math`) should work on them unchanged, so the nesting is easier to achieve.
This could be an initial implementation, to get an idea of what we need before we start changing the semantics of tensors.
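A minimal sketch of that advantage, using types that already parse today (the function name is illustrative): the element type is a `vector<32x32xf32>`, and `arith` operates on it without any changes.

```mlir
// Element-wise addition of two "tiles" stored as vector elements of a tensor.
func.func @add_tiles(%lhs: tensor<4x8xvector<32x32xf32>>,
                     %rhs: tensor<4x8xvector<32x32xf32>>,
                     %i: index, %j: index) -> vector<32x32xf32> {
  // Extracting an element yields a whole 32x32 vector, not a scalar.
  %v0 = tensor.extract %lhs[%i, %j] : tensor<4x8xvector<32x32xf32>>
  %v1 = tensor.extract %rhs[%i, %j] : tensor<4x8xvector<32x32xf32>>
  // arith ops accept vector operands unchanged.
  %sum = arith.addf %v0, %v1 : vector<32x32xf32>
  return %sum : vector<32x32xf32>
}
```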
As alluded to above, sparsity is a concern when nesting tensors.
However, for now, sparse tensors are restricted to the `sparse_tensor` dialect and should not be changed because of this proposal.
If the sparse community thinks it could benefit from nesting, we can start another thread/work in that direction.
There is work on distributed data structures (Ramba, C++ distributed arrays, BCL) that would benefit from distributed tensor semantics. Many HPC applications are divided across tiles and halos with cross-dependent semantics, across different sub-graphs of the execution path.
The biggest problem with distribution is that the tile size is not always the same across the same encompassing tensor.
The most common example is when the tensor dimensions aren't evenly divided by the number of nodes, so some nodes can get a slightly larger slice.
This means you can't have a `tensor<4x8 x tensor<32x32xf32>>`, because some of the elements will be `tensor<32x32xf32>` while others could be `tensor<31x31xf32>`.
One way to work around it is to create a layout description where the actual shape of the element type is defined by an equation (pseudo-IR, the shape expression is deliberately elided):

```mlir
%tiled = call @shard(%dense) : (tensor<128x255xf32>) -> tensor<4x8 x tensor<32x(...div...mod...)xf32>>
```
Being able to annotate such representations would allow compiler passes to operate directly on these concepts, on whole tensors, rather than splitting them into sub-tensors, outer loops and low-level control flow.
But this is a complex subject for separate discussion.
TODO

WIP: https://github.com/rengolin/llvm-project/tree/tensor-of-tensors

- Recursive attribute parsing
  - Dense splat is the simple case
  - Dense non-splat needs to assume contiguous (`<MxNxIxJ> == <MxNx<IxJ>>`)
- Forbid `encoding`, sparse, etc. for now
- Final type must be outer type
- `linalg.fill` to have the same semantics as constant attributes above
- `linalg.yield` to understand non-scalar types versus `linalg.generic` final return type
- `linalg.generic` to allow non-scalar ops inside the region