[Optimizer] Fixup grid set pass. Enable tests. Enable pass in TTNN pipeline. (#134)
nobradovictt authored Jul 15, 2024
1 parent 02ee9fd commit 04f952e
Showing 3 changed files with 9 additions and 14 deletions.
16 changes: 7 additions & 9 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -711,7 +711,6 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
 // Currently a placeholder pass for grid size optimization.
 // Goes through all the operations and sets the grid size to max supported
 // by target chip. Lacks:
-// - Proper update of layout attributes related to grid size.
 // - Constraint checking, whether the grid size is supported by the current
 //   OP based on inputs and op type.
 //
@@ -742,19 +741,18 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
 //
 GridAnalysisResult grid_analysis_result = grid_analysis.getResult();

-LayoutAttr layout = op->getResult(0)
-                        .getType()
-                        .template cast<RankedTensorType>()
-                        .getEncoding()
-                        .template cast<LayoutAttr>();
+RankedTensorType tensor_type =
+    op->getResult(0).getType().template cast<RankedTensorType>();
+LayoutAttr layout =
+    tensor_type.getEncoding().template cast<LayoutAttr>();
+llvm::ArrayRef<int64_t> tensor_shape = tensor_type.getShape();

 // Update the output layout attribute with the new grid size.
 //
-auto resultTy = op->getResult(0).getType().cast<RankedTensorType>();
 op->getResult(0).setType(RankedTensorType::get(
-    resultTy.getShape(), resultTy.getElementType(),
+    tensor_shape, tensor_type.getElementType(),
     layout.withGrid(
-        &getContext(), resultTy,
+        &getContext(), tensor_shape,
         GridAttr::get(&getContext(),
                       {grid_analysis_result.target_rows,
                        grid_analysis_result.target_columns}))));
4 changes: 1 addition & 3 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -166,9 +166,7 @@ class ConvertTTIRToTTNN

 void createTTIRToTTNNBackendPipeline(OpPassManager &pm) {
   pm.addPass(mlir::tt::ttir::createTTIRLayout());
-  // Not ready to be enabled by default.
-  //
-  // pm.addPass(mlir::tt::ttir::createTTIRGridSet());
+  pm.addPass(mlir::tt::ttir::createTTIRGridSet());
   pm.addPass(createTTNNOpenDevice());
   pm.addPass(createConvertTTIRToTTNN());
 }
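With the pass enabled by default, every run of createTTIRToTTNNBackendPipeline now applies the grid-size rewrite between layout assignment and device open. A hedged caller sketch, assuming standard MLIR pass-manager setup (dialect registration and module parsing omitted):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Sketch only: createTTIRToTTNNBackendPipeline is the function changed in
// the hunk above; obtaining `context` and `module`, and any namespace
// qualification of the pipeline function, are assumed boilerplate.
mlir::LogicalResult runBackendPipeline(mlir::MLIRContext &context,
                                       mlir::ModuleOp module) {
  mlir::PassManager pm(&context);
  createTTIRToTTNNBackendPipeline(pm); // now includes TTIRGridSet
  return pm.run(module);
}
```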
3 changes: 1 addition & 2 deletions test/ttmlir/Dialect/TTIR/test_grid_set.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-layout --ttir-grid-set %s | FileCheck %s
-// UNSUPPORTED: true
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {torch.debug_module_name = "_lambda", tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: #layout2 = #tt.layout<8192x128x1, undef, <8x8>, memref<64x128xf32, #l1_>>
+    // CHECK: #layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
     // CHECK: %[[C:.*]] = "ttir.layout"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttir.layout"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #layout2>
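The updated CHECK line also encodes the arithmetic the fixed pass now gets right: a 64x128 tensor placed on the full 8x8 grid shards to an 8x16 per-core memref (64/8 = 8 rows, 128/8 = 16 columns), where the old line matched an unsharded 64x128 memref. A hypothetical standalone check of that shard arithmetic (shardShape is illustrative, not tt-mlir API):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Per-core shard shape of a 2-D tensor distributed over a 2-D core grid.
std::array<int64_t, 2> shardShape(std::array<int64_t, 2> tensor,
                                  std::array<int64_t, 2> grid) {
  return {tensor[0] / grid[0], tensor[1] / grid[1]};
}

int main() {
  auto shard = shardShape({64, 128}, {8, 8});
  assert(shard[0] == 8 && shard[1] == 16); // matches memref<8x16xf32, #l1_>
  return 0;
}
```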
