
Tensor of Tensors


In order to allow for nested linalg.generic operations, we need tensors of tensors, i.e. tensor<N x M x tensor<I x J x f32>>, with N and M the outer dimensions and I and J the inner dimensions.

This can be the same as tensor<N x M x I x J x f32>, but it doesn't always have to be. Given that tensors don't carry representational guarantees (i.e. they need not be contiguous in memory), there is no requirement for the element types to carry any either (e.g. via an encoding). For example, you could have an outer sparse tensor of inner dense tensors, or vice versa. However, if all dimensions are dense, the two forms are equivalent. This will be important when we map to memrefs of memrefs.
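
As a rough sketch (tensors are not currently accepted as element types of tensors, so none of this parses today), the dense form and a mixed-layout form could look like the following; #outer_sparse is a hypothetical encoding attribute used only for illustration:

  // Dense outer, dense inner: equivalent to tensor<4x8x32x32xf32>
  !dense_of_dense = tensor<4x8x tensor<32x32xf32>>

  // Hypothetical sparse outer, dense inner: only the non-empty 32x32 tiles
  // are stored, so there is no equivalent flattened 4D dense type
  !sparse_of_dense = tensor<4x8x tensor<32x32xf32>, #outer_sparse>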

Motivation

This example is motivated by the nested linalg example. Basically, we want to allow linalg.generics inside linalg.generics to help with non-perfectly nested fusion (ex. GEMM + EW).

This allows us to tile a linalg.generic and still keep an outer linalg.generic, which can itself be tiled and fused, recursively. We can represent a lot of very complex patterns this way while keeping the tiling and fusing logic simple, since it doesn't need to worry about the order of the outer dimensions or the dimensionality of the inner dimensions.

Furthermore, it allows us to distribute tiles across nodes, threads, and lanes by nesting layout encodings. For example, one could use an outer layout to distribute the sub-tensors across different nodes, a middle layout to distribute across NUMA regions, and an inner (dense) layout to optimize for cache and register usage, per thread.
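
For illustration only, such a nesting could look like the sketch below; the #node_layout and #numa_layout encodings are hypothetical attributes, not existing ones:

  // Innermost: a dense tile, laid out for cache/register usage per thread
  !per_thread = tensor<32x32xf32>

  // Middle: tiles distributed across NUMA regions (hypothetical encoding)
  !per_numa = tensor<2x4x !per_thread, #numa_layout>

  // Outer: sub-tensors distributed across nodes (hypothetical encoding)
  !distributed = tensor<4x8x !per_numa, #node_layout>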

But to start the work, we need a simpler example, for instance a perfectly nested linalg.generic that is tiled into a nested linalg.generic.

Original:

  // GEMM maps
  #map = affine_map<(d0, d1, d2) -> (d0, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
  #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

  // 2D tensors
  func.func @entry(%arg0: tensor<256x1024xf32>) -> tensor<256x1024xf32> {

    // Constant weight
    %cst = arith.constant dense<1.000000e+00> : tensor<1024x1024xf32>

    // Zero accumulation
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<256x1024xf32>
    %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>

    // GEMM on scalar elements
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<256x1024xf32>, tensor<1024x1024xf32>) outs(%1 : tensor<256x1024xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<256x1024xf32>

    return %2 : tensor<256x1024xf32>
  }

Packed:

  // GEMM maps for 4D packed tensors
  #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
  #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
  #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

  // 4D tensors
  func.func @entry(%arg0: tensor<4x16x64x64xf32>) -> tensor<4x16x64x64xf32> {

    // Constant 4D weight
    %cst = arith.constant dense<1.000000e+00> : tensor<16x16x64x64xf32>

    // Zero 4D accumulation
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<4x16x64x64xf32>
    %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4x16x64x64xf32>) -> tensor<4x16x64x64xf32>

    // Perfectly nested 4D GEMM (ready for tiling)
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<4x16x64x64xf32>, tensor<16x16x64x64xf32>) outs(%1 : tensor<4x16x64x64xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<4x16x64x64xf32>

    return %2 : tensor<4x16x64x64xf32>
  }

Tensors of tensors:

  // GEMM maps (note, exactly the same as 2D)
  #map = affine_map<(d0, d1, d2) -> (d0, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
  #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

  // 2D tensors (of tensors)
  func.func @entry(%arg0: tensor<4x8x tensor<32x32xf32>>) -> tensor<4x8x tensor<32x32xf32>> {
    // Constant weight "2D"
    %cst = arith.constant dense<1.000000e+00> : tensor<8x8x tensor<32x32xf32>>

    // Zero accumulation "2D"
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<4x8x tensor<32x32xf32>>
    %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4x8x tensor<32x32xf32>>) -> tensor<4x8x tensor<32x32xf32>>

    // GEMM on tensor elements (still "2D")
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<4x8x tensor<32x32xf32>>, tensor<8x8x tensor<32x32xf32>>) outs(%1 : tensor<4x8x tensor<32x32xf32>>) {
    ^bb0(%in: tensor<32x32xf32>, %in_1: tensor<32x32xf32>, %out: tensor<32x32xf32>):

      // Tile GEMM 2D (equivalent to linalg.generic, but let's use matmul for simplicity)
      %3 = linalg.matmul ins(%in, %in_1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
      linalg.yield %3 : tensor<32x32xf32>
    } -> tensor<4x8x tensor<32x32xf32>>

    return %2 : tensor<4x8x tensor<32x32xf32>>
  }

Semantics

Layout Expectations

The layout expectations are exactly the same as before: there aren't any.

With encoding specifications (or future distributed / sparse attributes), one can change the expectation of a tensor. However, on a tensor of scalars, the encoding is all-or-nothing, i.e. the entire tensor must adopt the same encoding. This leads to compilers splitting the tensor into multiple smaller tensors, where only the compiler knows they are actually sub-tensors of the original. Once that IR is produced, a downstream compiler has lost that information.

With tensors of tensors, however, we can not only keep that information in the IR, but also write compiler passes that use it for fusion. For example, two unrelated computations can only be fused after tiling if their iteration semantics are identical. If tiling produces loops, the iteration semantics can only be recovered through induction-variable scalar-evolution analysis, which may not be complete.

With linalg.generic applying regions to sub-tensors, the traversal semantics are encoded in affine maps and tensor attributes and can be manipulated more easily, even if a generic pass doesn't understand the underlying encoding. If those semantics are strongly enforced at construction and verification time for each encoding, then generic passes can assume that two operations carrying the same encoding iterate identically and can therefore be fused together.

Furthermore, making those representations explicit could allow the compiler to reorder sub-tensors to match subsequent operations, mark operations for fusion at different nesting levels, etc.

Extracting tensor element types

The big question here is: what does it mean to have an element type of tensor<32x32xf32>?

In theory, tensors do NOT have a pre-defined layout, so representing a tile as a 32 by 32 tensor is perfectly fine. In practice, tiles are not contiguous in memory, and bufferization may assume they are if it doesn't know that this is an element type.

This document assumes that reading and writing an element of tensor type is equivalent to tensor.extract_slice and tensor.insert_slice. While bufferization converts those into memref.subview with stride information, tensor has no such notion.
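
For example, reading one element of the nested form would correspond, on the flattened form, to an extract_slice of the same tile. This is a sketch: the tensor.extract returning a tensor-typed element is the assumed behaviour, not something MLIR supports today.

  func.func @element_vs_slice(%nested: tensor<4x8x tensor<32x32xf32>>,
                              %flat: tensor<4x8x32x32xf32>,
                              %i: index, %j: index) {
    // Hypothetical: tensor.extract returning the tensor-typed element
    %tile = tensor.extract %nested[%i, %j] : tensor<4x8x tensor<32x32xf32>>

    // The same tile on the flattened form: a (possibly strided) slice
    %slice = tensor.extract_slice %flat[%i, %j, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
        : tensor<4x8x32x32xf32> to tensor<32x32xf32>
    return
  }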

The following example is not invalid, but its behaviour is undefined:

func.func @some_random_func(%arg0 : tensor<4x8x32x32xf32>) { // Let's assume this is contiguous
  // This can also be assumed to be a dense contiguous 32x32 tensor
  %buf = tensor.empty() : tensor<32x32xf32>

  // This is most certainly not contiguous in memory
  %slice = tensor.extract_slice %arg0[1, 2, 0, 0][1, 1, 32, 32][1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32>

  // At the tensor level, how does `@some_func` know the layout of the tensor?
  %0 = call @some_func(%buf) : (tensor<32x32xf32>) -> tensor<32x32xf32>   // Dense
  %1 = call @some_func(%slice) : (tensor<32x32xf32>) -> tensor<32x32xf32> // Strided
  return
}

Mapping to Memrefs

When mapping tensor of tensors to memrefs there are two options:

  1. Implement memref of memrefs
  2. Full encoding-aware lowering

The first case is just a delayed version of the second, since we'll have to fully lower eventually, but it allows the compiler to keep running fusion passes while already at the memref level. However, it also means memref operations need to be adapted to understand the semantics of non-scalar element types.

Since we'll have to implement full lowering at some point for this work to be complete, it would be quicker to adopt the second strategy first and only implement memref of memrefs if and when it becomes necessary.
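
As a sketch of what the full lowering would produce, each 32x32 "element" ends up as a strided view into the enclosing buffer, which is where the layout information that the tensor level cannot express gets materialized:

  // After bufferization: one tile of a row-major 128x256 buffer.
  // Rows of the tile are 256 floats apart, so the view is not contiguous.
  func.func @lowered_tile(%buf: memref<128x256xf32>, %r: index, %c: index) {
    %tile = memref.subview %buf[%r, %c] [32, 32] [1, 1]
        : memref<128x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
    return
  }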

Versus Tensors of Vectors

Tensors of vectors can already be used, but there are two main hindrances:

  1. Vectors have restricted semantics (dense, cache/register location, register semantics)
  2. Vectors can only have scalar element types

While the second restriction isn't quite a problem, the first one is, as it would not allow us to have sparse or distributed semantics, or recursive nesting.

However, tensors of vectors have a clear advantage: scalar operations (arith/math) should work on them unchanged, so the nesting is easier to achieve. This could be an initial implementation, to get an idea of what we need before we start changing the semantics of tensors.
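
A sketch of what that initial implementation could look like, assuming linalg.generic and tensor.empty accept vector element types (validating that is part of the exercise):

  #id = affine_map<(d0, d1) -> (d0, d1)>

  // Element-wise add over a tensor of vectors: the arith op in the region
  // operates on whole vector<32x32xf32> elements unchanged
  func.func @vec_add(%a: tensor<4x8xvector<32x32xf32>>, %b: tensor<4x8xvector<32x32xf32>>) -> tensor<4x8xvector<32x32xf32>> {
    %0 = tensor.empty() : tensor<4x8xvector<32x32xf32>>
    %1 = linalg.generic {indexing_maps = [#id, #id, #id], iterator_types = ["parallel", "parallel"]} ins(%a, %b : tensor<4x8xvector<32x32xf32>>, tensor<4x8xvector<32x32xf32>>) outs(%0 : tensor<4x8xvector<32x32xf32>>) {
    ^bb0(%in: vector<32x32xf32>, %in_1: vector<32x32xf32>, %out: vector<32x32xf32>):
      %2 = arith.addf %in, %in_1 : vector<32x32xf32>
      linalg.yield %2 : vector<32x32xf32>
    } -> tensor<4x8xvector<32x32xf32>>
    return %1 : tensor<4x8xvector<32x32xf32>>
  }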

Sparse Tensor

As alluded above, sparsity is a concern when nesting tensors. However, for now, sparse tensors are restricted to the sparse_tensor dialect and should not be changed because of this proposal. If the sparse community thinks it could benefit from nesting, we can start another thread/work in that direction.

Distributed Tensors

There is work on distributed data structures (Ramba, C++ distributed arrays, BCL) that would benefit from distributed tensor semantics. Many HPC applications are divided across tiles and halos with cross-dependent semantics, across different sub-graphs of the execution path.

The biggest problem with distribution is that the tile size is not always the same across the same encompassing tensor. The most common example is when the tensor dimensions aren't evenly divided by the number of nodes, so some nodes can get a slightly larger slice. This means you can't have a tensor<4x8 x tensor<32x32xf32>> because some of them will be tensor<32x32xf32> while others could be tensor<31x31xf32>.

One way to work around it is to create a layout description where the actual shape of the element type is defined by an equation:

  %tiled = call @shard(%dense : tensor<128x255xf32>) : tensor<8x4 x tensor<32x(...div...mod...)xf32>>

Being able to annotate such representations would allow compiler passes to operate on these concepts directly at the tensor level, rather than splitting them into sub-tensors, outer loops, and low-level control flow.

But this is a complex subject for separate discussion.

Tiling a Tensor of Tensor

TODO

Steps to have tensors as a valid element type

WIP: https://github.com/rengolin/llvm-project/tree/tensor-of-tensors

Constants

  • Recursive attribute parsing
    • Dense splat is the simple case
    • Dense non-splat needs to assume contiguous (<MxNxIxJ> == <MxNx<IxJ>>); see the sketch after this list
    • Forbid encoding, sparse, etc. for now
    • Final type must be outer type
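
A minimal sketch of what the constant cases above could look like once the element-type restriction is lifted (not valid MLIR today):

  // Dense splat: every 32x32 tile is filled with 1.0, final type is the outer type
  %splat = arith.constant dense<1.000000e+00> : tensor<4x8x tensor<32x32xf32>>

  // Dense non-splat: the flat value list is assumed contiguous, i.e. it is
  // interpreted exactly as it would be for tensor<2x2x2x2xf32>
  %full = arith.constant dense<[[[[1.0, 2.0], [3.0, 4.0]],
                                 [[5.0, 6.0], [7.0, 8.0]]],
                                [[[9.0, 10.0], [11.0, 12.0]],
                                 [[13.0, 14.0], [15.0, 16.0]]]]>
      : tensor<2x2x tensor<2x2xf32>>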

Linalg operations

  • linalg.fill to have the same semantics as constant attributes above
  • linalg.yield to understand non-scalar types versus linalg.generic final return type
  • linalg.generic to allow non-scalar ops inside the region

More to come...