From f39d4d254ec5300ee30243990552f3908ff725f1 Mon Sep 17 00:00:00 2001 From: Nikola Obradovic Date: Mon, 22 Jul 2024 08:41:59 +0000 Subject: [PATCH] [Optimizer] Per op grid overrides. --- .../Dialect/TTIR/Analysis/GridAnalysis.h | 16 +++-- .../Dialect/TTIR/Analysis/TTIRAnalysis.h | 15 ++++- include/ttmlir/Dialect/TTIR/Passes.td | 6 ++ include/ttmlir/Dialect/TTNN/Passes.h | 66 +++++++++++++++++++ lib/Dialect/TTIR/Analysis/GridAnalysis.cpp | 19 ++++++ lib/Dialect/TTIR/Transforms/Passes.cpp | 5 +- lib/Dialect/TTNN/Transforms/Passes.cpp | 5 +- .../multiple_add_with_loc_grid_override.mlir | 28 ++++++++ 8 files changed, 152 insertions(+), 8 deletions(-) create mode 100644 test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir diff --git a/include/ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h b/include/ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h index fbc61a9614..6a306d4de7 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h +++ b/include/ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h @@ -6,6 +6,7 @@ #define TTMLIR_DIALECT_TTIR_ANALYSIS_GRIDANALYSIS_H #include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h" +#include "llvm/ADT/StringMap.h" namespace mlir::tt::ttir { @@ -17,16 +18,22 @@ struct GridAnalysisResult { struct GridAnalysisInput { int max_supported_rows; int max_supported_columns; + llvm::StringMap<SmallVector<int64_t, 2>> *grid_size_overrides; - GridAnalysisInput() : max_supported_rows(1), max_supported_columns(1) {} + GridAnalysisInput() + : max_supported_rows(1), max_supported_columns(1), + grid_size_overrides(nullptr) {} - GridAnalysisInput(int max_supported_rows, int max_supported_columns) + GridAnalysisInput(int max_supported_rows, int max_supported_columns, + llvm::StringMap<SmallVector<int64_t, 2>> *grid_size_overrides) : max_supported_rows(max_supported_rows), - max_supported_columns(max_supported_columns) {} + max_supported_columns(max_supported_columns), + grid_size_overrides(grid_size_overrides) {} bool operator==(const GridAnalysisInput &rhs) const { return max_supported_rows == 
rhs.max_supported_rows && - max_supported_columns == rhs.max_supported_columns; + max_supported_columns == rhs.max_supported_columns && + grid_size_overrides == rhs.grid_size_overrides; } bool operator!=(const GridAnalysisInput &rhs) const { @@ -41,6 +48,7 @@ class GridAnalysis private: void analysisImplementation() override; + bool applyOverrides() override; public: GridAnalysis(Operation *op) : TTIRAnalysis<GridAnalysisInput, GridAnalysisResult>(op) {} diff --git a/include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h b/include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h index 32f39b804e..95180f935b 100644 --- a/include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h +++ b/include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h @@ -24,6 +24,12 @@ template <class I, class R> class TTIRAnalysis { // virtual void analysisImplementation() = 0; + // Load overrides if they exist. + // Must be implemented by every analysis type. + // Returns true if analysis should be skipped. + // + virtual bool applyOverrides() = 0; + public: virtual ~TTIRAnalysis() {}; @@ -52,7 +58,14 @@ template <class I, class R> class TTIRAnalysis { // Skip the analysis if it was already run and input params haven't changed. // if (!is_valid) { - analysisImplementation(); + // Apply overrides if needed. + // + bool skip_analysis = applyOverrides(); + + if (!skip_analysis) { + analysisImplementation(); + } + is_valid = true; } } diff --git a/include/ttmlir/Dialect/TTIR/Passes.td b/include/ttmlir/Dialect/TTIR/Passes.td index 357d38d9cd..987ce57c16 100644 --- a/include/ttmlir/Dialect/TTIR/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Passes.td @@ -48,6 +48,12 @@ def TTIRGridSet: Pass<"ttir-grid-set", "::mlir::ModuleOp"> { Go through the ops, set grid size for each op based on grid analysis, by updating layout attribute of each op. 
}]; + let options = [ + Option<"overrideGridSizes", "override-grid-sizes", + "llvm::StringMap<llvm::SmallVector<int64_t, 2>>", + /*default=*/"llvm::StringMap<llvm::SmallVector<int64_t, 2>>()", + "Override grid sizes for specific ops.">, + ]; } #endif diff --git a/include/ttmlir/Dialect/TTNN/Passes.h b/include/ttmlir/Dialect/TTNN/Passes.h index b7e6089357..f80ed457ba 100644 --- a/include/ttmlir/Dialect/TTNN/Passes.h +++ b/include/ttmlir/Dialect/TTNN/Passes.h @@ -18,6 +18,58 @@ namespace mlir::tt::ttnn { #define GEN_PASS_REGISTRATION #include "ttmlir/Dialect/TTNN/Passes.h.inc" +struct GridSizeOverrideParser + : public llvm::cl::parser<llvm::StringMap<SmallVector<int64_t, 2>>> { +public: + GridSizeOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser<llvm::StringMap<SmallVector<int64_t, 2>>>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap<SmallVector<int64_t, 2>> &value) { + SmallVector<StringRef> overrideList; + constexpr size_t kvPairSize = 2; + constexpr size_t kMaxGridSize = 2; + constexpr size_t iOpName = 0; + constexpr size_t iGrid = 1; + arg.split(overrideList, ','); + for (const StringRef override : overrideList) { + SmallVector<StringRef> kv; + override.split(kv, '='); + if (kv.size() != kvPairSize) { + opt.error("Invalid format for override grid sizes: " + override); + return true; + } + SmallVector<int64_t, kMaxGridSize> grid; + SmallVector<StringRef, kMaxGridSize> gridParts; + kv[iGrid].split(gridParts, 'x'); + for (const StringRef gridPart : gridParts) { + int gridValue; + if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { + opt.error("Invalid grid size: " + gridPart); + return true; + } + grid.push_back(gridValue); + } + value[kv[iOpName]] = grid; + } + return false; + } + + static void print(llvm::raw_ostream &os, + const llvm::StringMap<SmallVector<int64_t, 2>> &value) { + os << "override-grid-sizes="; + size_t count = 0; + for (const auto &entry : value) { + os << entry.getKey() << "="; + os << entry.getValue()[0] << "x" << entry.getValue()[1]; + if (++count < value.size()) { + os << ","; + } + } + os << "\n"; + } +}; + // Options for the TTIR to TTNN backend pipeline. 
// struct TTIRToTTNNBackendPipelineOptions @@ -30,6 +82,20 @@ struct TTIRToTTNNBackendPipelineOptions *this, "enable-grid-set", llvm::cl::desc("Determine and set max valid grid for Op execution."), llvm::cl::init(true)}; + + // Option to override grid size for specific ops. + // The format is a comma separated list of op names and grid sizes. + // + // Example: "op1=2x2,op2=4x4" + // + // This will set the grid size for op1 to 2x2 and op2 to 4x4. + // + // Note: This option is only valid if gridSetPassEnabled is true. + // + Option<llvm::StringMap<SmallVector<int64_t, 2>>, GridSizeOverrideParser> + overrideGridSizes{*this, "override-grid-sizes", + llvm::cl::desc("Override grid sizes for specific ops."), + llvm::cl::init(llvm::StringMap<SmallVector<int64_t, 2>>())}; }; void createTTIRToTTNNBackendPipeline( diff --git a/lib/Dialect/TTIR/Analysis/GridAnalysis.cpp b/lib/Dialect/TTIR/Analysis/GridAnalysis.cpp index 7ecbf05510..bdde174a49 100644 --- a/lib/Dialect/TTIR/Analysis/GridAnalysis.cpp +++ b/lib/Dialect/TTIR/Analysis/GridAnalysis.cpp @@ -5,6 +5,25 @@ #include "ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h" namespace mlir::tt::ttir { + +bool GridAnalysis::applyOverrides() { + // Lookup grid size overrides based on location information for current + // operation. + // + if (analysis_input.grid_size_overrides && op->getLoc().isa<NameLoc>()) { + StringRef loc_str_op_name = op->getLoc().cast<NameLoc>().getName(); + auto grid_override = + analysis_input.grid_size_overrides->find(loc_str_op_name); + if (grid_override != analysis_input.grid_size_overrides->end()) { + analysis_result.target_rows = grid_override->second[0]; + analysis_result.target_columns = grid_override->second[1]; + return true; + } + } + + return false; +} + void GridAnalysis::analysisImplementation() { // Placeholder. For now result of analysis is maximum supported grid size. 
// diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp index 662b018b69..9f10b0afd5 100644 --- a/lib/Dialect/TTIR/Transforms/Passes.cpp +++ b/lib/Dialect/TTIR/Transforms/Passes.cpp @@ -748,8 +748,9 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> { // Initialize the grid analysis with the max grid size. // - grid_analysis.init( - GridAnalysisInput(max_grid.getShape()[0], max_grid.getShape()[1])); + grid_analysis.init(GridAnalysisInput(max_grid.getShape()[0], + max_grid.getShape()[1], + &overrideGridSizes)); // Run the grid analysis and get the result. // diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 358a025f1c..9f31777f77 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -181,8 +181,11 @@ class ConvertTTIRToTTNN void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { pm.addPass(mlir::tt::ttir::createTTIRLayout()); + if (options.gridSetPassEnabled) { - pm.addPass(mlir::tt::ttir::createTTIRGridSet()); + ttir::TTIRGridSetOptions gridSetOptions; + gridSetOptions.overrideGridSizes = options.overrideGridSizes; + pm.addPass(mlir::tt::ttir::createTTIRGridSet(gridSetOptions)); } pm.addPass(createTTNNOpenDevice()); diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir new file mode 100644 index 0000000000..ae262370c9 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir @@ -0,0 +1,28 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="override-grid-sizes=add_1_0=4x4,add_2_0=4x4" %s | FileCheck %s +#any_device = #tt.operand_constraint<any_device> +#loc = loc("test_ops.py:17_0_0":0:0) +module @pybuda_graph attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, 
noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [], [<0, 0, 0, 0>]>} { + func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { + // CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>> + // CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>> + // CHECK: #layout3 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>> + %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) + // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2> + %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) + %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6) + // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2> + %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) + %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7) + // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout3> + %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) + // CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1> + return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4) + } loc(#loc) +} loc(#loc) +#loc1 = 
loc("test_ops.py:17_0_0":0:4) +#loc2 = loc("test_ops.py:17_0_0":0:6) +#loc3 = loc("test_ops.py:17_0_0":0:3) +#loc4 = loc(unknown) +#loc5 = loc("add_1_0"(#loc1)) +#loc6 = loc("add_2_0"(#loc2)) +#loc7 = loc("add_0"(#loc3))