diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 7024984d9..75471f6e4 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -724,9 +724,14 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
                             .getChipDescs()[0]
                             .getGrid();
     module_op->walk([&](func::FuncOp func) {
-      Type lastOpResultType;
+      SmallVector<Type> funcResultTypes;
       func->walk([&](Operation *op) {
         if (op->getNumResults() == 0) {
+          func::ReturnOp funcReturn = dyn_cast<func::ReturnOp>(op);
+          if (funcReturn) {
+            funcResultTypes.append(funcReturn.getOperandTypes().begin(),
+                                   funcReturn.getOperandTypes().end());
+          }
           return;
         }
 
@@ -756,15 +761,14 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {
                  GridAttr::get(&getContext(),
                                {grid_analysis_result.target_rows,
                                 grid_analysis_result.target_columns}))));
-        lastOpResultType = op->getResult(0).getType();
       });
 
-      // Update the function type to reflect the last operation's result type.
+      // Update the function type to reflect the updated return operation's
+      // result types.
       //
       FunctionType func_type = func.getFunctionType();
-      SmallVector<Type> newReturnTypes = {lastOpResultType};
       FunctionType newFuncType = FunctionType::get(
-          func.getContext(), func_type.getInputs(), newReturnTypes);
+          func.getContext(), func_type.getInputs(), funcResultTypes);
       func.setType(newFuncType);
     });
   }
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
new file mode 100644
index 000000000..de9d2d9ea
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
@@ -0,0 +1,27 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#loc = loc("test_ops.py:17_0_0":0:0)
+module @pybuda_graph attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [], [<0, 0, 0, 0>]>} {
+  func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
+    // CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
+    // CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
+    %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
+    %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
+    // CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
+    return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
+  } loc(#loc)
+} loc(#loc)
+#loc1 = loc("test_ops.py:17_0_0":0:4)
+#loc2 = loc("test_ops.py:17_0_0":0:6)
+#loc3 = loc("test_ops.py:17_0_0":0:3)
+#loc4 = loc(unknown)
+#loc5 = loc("add_1_0"(#loc1))
+#loc6 = loc("add_2_0"(#loc2))
+#loc7 = loc("add_0"(#loc3))
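
A minimal standalone sketch of the pattern the Passes.cpp hunk applies, assuming only the upstream MLIR func dialect API; the helper name updateFuncResultTypes and the pass boilerplate it omits are illustrative, not part of the patch:

// Illustrative helper (not part of the patch): walk a func.func, collect the
// operand types of its func.return, and rebuild the FunctionType so the
// signature matches the rewritten body.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void updateFuncResultTypes(func::FuncOp func) {
  llvm::SmallVector<Type> funcResultTypes;
  func->walk([&](Operation *op) {
    if (op->getNumResults() == 0) {
      // func.return has no results, but its operand types are exactly the
      // result types the enclosing function must advertise.
      if (auto funcReturn = dyn_cast<func::ReturnOp>(op)) {
        funcResultTypes.append(funcReturn.getOperandTypes().begin(),
                               funcReturn.getOperandTypes().end());
      }
      return;
    }
    // ... per-op grid/layout rewriting happens here in the real pass ...
  });

  // Rebuild the signature from the collected return types rather than from
  // the last visited op, so functions returning multiple values stay valid.
  FunctionType oldType = func.getFunctionType();
  func.setType(FunctionType::get(func.getContext(), oldType.getInputs(),
                                 funcResultTypes));
}

This is also what the new test exercises: @main returns two values, which the previous single-result update of the function type could not represent.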