Skip to content

Commit

Permalink
Add V0 optimizer tensor layout generation (#518)
Browse files Browse the repository at this point in the history
First version of op config generation. Includes only tensor layout for now.

* Generate all legal shard specs
* Add mock calls for legality checks which need to be replaced by TTNN interface.
  • Loading branch information
odjuricicTT authored Sep 6, 2024
1 parent 19ee116 commit 6eb09be
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 71 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct LegalGridAnalysisInput {
ChipDescAttr chipDesc;
GridAttr maxGrid;
RankedTensorType tensorType;
int64_t maxShardedGrids = 64;
llvm::StringMap<SmallVector<int64_t, 2>> *gridSizeOverrides;

LegalGridAnalysisInput()
Expand Down
147 changes: 139 additions & 8 deletions lib/Dialect/TTIR/Analysis/LegalGridAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,48 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTIR/Analysis/LegalGridAnalysis.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

namespace mlir::tt::ttir {

bool mock_is_output_tensor_legal_for_op(Operation *op, LayoutAttr layout) {
// Placeholder, needs to be replaced with a call the the TTNN op interface.
return true;
}

bool tensor_shape_compatible_with_shard(Operation *op, LayoutAttr layout) {
// These constraints are implemented seperatelly in every TTNN op.
// Almost nothing seems to be shared between EVERY op, so is hard to have any
// logic here without the risk of discarding a valid configuraiton or modeling
// the constraint for each op. This logic may be offloaded to the TTNN op
// interface.

// For now we will check if the tilised tensor dims are divisible by the grid
// dims. This will definitly discard possible valid configurations, but is a
// start.
RankedTensorType tensorType =
mlir::cast<RankedTensorType>(op->getResult(0).getType());
llvm::ArrayRef<int64_t> tensorShape = tensorType.getShape();

int64_t MTiles = 1;
if (tensorType.getRank() >= 2) {
MTiles = (tensorShape.rbegin()[1] + 31) / 32;
}

int64_t KTIles = (tensorShape.back() + 31) / 32;

int64_t gridR = layout.getGrid().getShape()[0];
int64_t gridC = layout.getGrid().getShape()[1];

return (MTiles % gridR == 0) && (KTIles % gridC == 0);
}

bool LegalGridAnalysis::applyOverrides() {
// Lookup grid size overrides based on location information for current
// operation.
//

// TODO(odjuricic): Need to override all params, not just grid size.
RankedTensorType tensorType =
mlir::cast<RankedTensorType>(op->getResult(0).getType());
LayoutAttr layout = mlir::cast<LayoutAttr>(tensorType.getEncoding());
Expand All @@ -36,17 +70,114 @@ bool LegalGridAnalysis::applyOverrides() {
}

void LegalGridAnalysis::analysisImplementation() {
// Placeholder, needs to be implemented. Go through all the grid sizes and
// check if they are legal based on tensor type and device/chip attributes.
// For now result of analysis is maximum supported grid size.
//
// A first incomplete implementation of the LegalGridAnalysis.
// This implementation is a placeholder and is meant to just enable testing of
// other components.

// Process only TTIR ops.
if (not llvm::isa<TTIROp>(op)) {
return;
}
// Skip operations that don't have output tensors.
if (op->getNumResults() == 0) {
return;
}
if (llvm::isa<ToLayoutOp>(op)) {
return;
}

// Get output tensor type.
RankedTensorType tensorType =
mlir::cast<RankedTensorType>(op->getResult(0).getType());
LayoutAttr layout = mlir::cast<LayoutAttr>(tensorType.getEncoding());
llvm::ArrayRef<int64_t> tensorShape = tensorType.getShape();

analysisResult.push_back(layout.withGrid(
op->getContext(), tensorShape,
GridAttr::get(op->getContext(), analysisInput.maxGrid.getShape())));
// DRAM
// No grid is set since the tensor is not sharded.
// TODO(odjuricic): We need to set grid here since it will be used as the
// compute gird. (not implemented in runtime atm)
LayoutAttr dram =
layout.withMemorySpace(op->getContext(), MemorySpace::DeviceDRAM)
.withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved)
.withGrid(op->getContext(), tensorType,
GridAttr::get(op->getContext(),
analysisInput.maxGrid.getShape()));
if (mock_is_output_tensor_legal_for_op(op, dram)) {
analysisResult.push_back(dram);
}

// L1 Interleaved (same as above).
LayoutAttr l1Interleaved =
layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1)
.withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved)
.withGrid(op->getContext(), tensorType,
GridAttr::get(op->getContext(),
analysisInput.maxGrid.getShape()));
if (mock_is_output_tensor_legal_for_op(op, l1Interleaved)) {
analysisResult.push_back(l1Interleaved);
}

// L1 Sharded
LayoutAttr shardedBase =
layout.withMemorySpace(op->getContext(), MemorySpace::DeviceL1);
std::vector<LayoutAttr> shardedResults;

// Block Sharded
for (auto width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) {
for (auto height = 1; height <= analysisInput.maxGrid.getShape()[1];
++height) {
shardedResults.push_back(
shardedBase
.withGrid(op->getContext(), tensorType,
GridAttr::get(op->getContext(), {width, height}))
.withMemoryLayout(op->getContext(),
TensorMemoryLayout::BlockSharded));
}
}

auto numCores =
analysisInput.maxGrid.getShape()[0] * analysisInput.maxGrid.getShape()[1];
// Height Sharded
// TODO(odjuricic): Missing affine mapping to actual grid. Need to check with
// runtime implementation on what to produce here.
for (auto height = 2; height <= numCores; ++height) {
shardedResults.push_back(
shardedBase
.withGrid(op->getContext(), tensorType,
GridAttr::get(op->getContext(), {height, 1}))
.withMemoryLayout(op->getContext(),
TensorMemoryLayout::HeightSharded));
}

// Width Sharded
for (auto width = 2; width <= numCores; ++width) {
shardedResults.push_back(
shardedBase
.withGrid(op->getContext(), tensorType,
GridAttr::get(op->getContext(), {1, width}))
.withMemoryLayout(op->getContext(),
TensorMemoryLayout::WidthSharded));
}

// Filter layouts based on output tensor legality for current op.
shardedResults.erase(
std::remove_if(shardedResults.begin(), shardedResults.end(),
[this](LayoutAttr layout) {
return !tensor_shape_compatible_with_shard(op, layout) ||
!mock_is_output_tensor_legal_for_op(op, layout);
}),
shardedResults.end());

// Pick top largest sharded grids.
std::sort(shardedResults.begin(), shardedResults.end(),
[](LayoutAttr a, LayoutAttr b) {
return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] >
b.getGrid().getShape()[0] * b.getGrid().getShape()[1];
});

analysisResult.insert(
analysisResult.end(), shardedResults.begin(),
shardedResults.begin() +
std::min(analysisInput.maxShardedGrids,
static_cast<int64_t>(shardedResults.size())));
}
} // namespace mlir::tt::ttir
4 changes: 3 additions & 1 deletion lib/Dialect/TTIR/Analysis/OptimalTargetGridAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ void OptimalTargetGridAnalysis::analysisImplementation() {
// Placeholder: pick the first legal grid.
//
for (auto opGrids : analysisInput.legalGrids) {
analysisResult[opGrids.first] = opGrids.second[0];
if (not opGrids.second.empty()) {
analysisResult[opGrids.first] = opGrids.second[0];
}
}
}
} // namespace mlir::tt::ttir
16 changes: 13 additions & 3 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "ttmlir/Dialect/TTIR/Analysis/OptimalTargetGridAnalysis.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Utils.h"
#include <mlir/Interfaces/DestinationStyleOpInterface.h>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRGENERICKERNEL
Expand Down Expand Up @@ -1090,9 +1091,18 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {

// Update the output layout attribute with the new grid size.
//
op->getResult(0).setType(RankedTensorType::get(
tensorShape, tensorType.getElementType(),
optimalTargetGridAnalysis.getResult().at(op)));
if (optimalTargetGridAnalysis.getResult().contains(op)) {
RankedTensorType newTensorType = RankedTensorType::get(
tensorShape, tensorType.getElementType(),
optimalTargetGridAnalysis.getResult().at(op));

op->getResult(0).setType(newTensorType);

if (llvm::isa<mlir::DestinationStyleOpInterface>(op)) {
// Update dps operand layout as well.
op->getOperands().back().setType(newTensorType);
}
}
});

// Update the function type to reflect the updated return operation's
Expand Down
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTIR/test_grid_set.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved>
// CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved>
// CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_2]]>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
Expand Down
10 changes: 5 additions & 5 deletions test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
#loc = loc("test_ops.py:17_0_0":0:0)
module attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
%1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
%5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
// CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
// CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #layout>, tensor<1x32x32xf32, #layout>
return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
} loc(#loc)
} loc(#loc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
#loc = loc("test_ops.py:17_0_0":0:0)
module attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
// CHECK: #[[LAYOUT_0:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
// CHECK: #[[LAYOUT_0:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #system>>
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #dram>, interleaved>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
// CHECK: #[[LAYOUT_3:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_3]]>
%5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
// CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #[[LAYOUT_0]]>, tensor<1x32x32xf32, #[[LAYOUT_0]]>
return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
Expand Down
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved>
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
// CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_2]]>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<64x128xf32>
Expand Down
Loading

0 comments on commit 6eb09be

Please sign in to comment.