Skip to content

Commit

Permalink
[GridSetPass] Fix handling of function with multiple results. (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
nobradovictt authored Jul 18, 2024
1 parent a2a8fe5 commit 32f4863
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
14 changes: 9 additions & 5 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,14 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
.getChipDescs()[0]
.getGrid();
module_op->walk([&](func::FuncOp func) {
Type lastOpResultType;
SmallVector<Type> funcResultTypes;
func->walk([&](Operation *op) {
if (op->getNumResults() == 0) {
func::ReturnOp funcReturn = dyn_cast<func::ReturnOp>(op);
if (funcReturn) {
funcResultTypes.append(funcReturn.getOperandTypes().begin(),
funcReturn.getOperandTypes().end());
}
return;
}

Expand Down Expand Up @@ -756,15 +761,14 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
GridAttr::get(&getContext(),
{grid_analysis_result.target_rows,
grid_analysis_result.target_columns}))));
lastOpResultType = op->getResult(0).getType();
});

// Update the function type to reflect the last operation's result type.
// Update the function type to reflect the updated return operation's
// result types.
//
FunctionType func_type = func.getFunctionType();
SmallVector<Type> newReturnTypes = {lastOpResultType};
FunctionType newFuncType = FunctionType::get(
func.getContext(), func_type.getInputs(), newReturnTypes);
func.getContext(), func_type.getInputs(), funcResultTypes);
func.setType(newFuncType);
});
}
Expand Down
27 changes: 27 additions & 0 deletions test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
//
// Exercises a function with two results: every ttir.add below is expected to
// lower to a ttnn.add producing #layout2, and both operands of the final
// return are expected to carry #layout1.
//
// SSA value names are matched with wildcards rather than hard-coded numbers so
// that the test does not depend on how many ops the pipeline happens to emit.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("test_ops.py:17_0_0":0:0)
module @pybuda_graph attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
// CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
// CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #layout2>
%1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #layout2>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
// CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #layout2>
%5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
// Both returned values must have been converted to #layout1.
// CHECK: return %{{.*}}, %{{.*}} : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
} loc(#loc)
} loc(#loc)
#loc1 = loc("test_ops.py:17_0_0":0:4)
#loc2 = loc("test_ops.py:17_0_0":0:6)
#loc3 = loc("test_ops.py:17_0_0":0:3)
#loc4 = loc(unknown)
#loc5 = loc("add_1_0"(#loc1))
#loc6 = loc("add_2_0"(#loc2))
#loc7 = loc("add_0"(#loc3))

0 comments on commit 32f4863

Please sign in to comment.