Skip to content

Commit

Permalink
[Optimizer] Per op grid overrides.
Browse files Browse the repository at this point in the history
  • Loading branch information
nobradovictt committed Jul 22, 2024
1 parent 25bb6ae commit f39d4d2
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 8 deletions.
16 changes: 12 additions & 4 deletions include/ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_DIALECT_TTIR_ANALYSIS_GRIDANALYSIS_H

#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h"
#include "llvm/ADT/StringMap.h"

namespace mlir::tt::ttir {

Expand All @@ -17,16 +18,22 @@ struct GridAnalysisResult {
struct GridAnalysisInput {
int max_supported_rows;
int max_supported_columns;
llvm::StringMap<SmallVector<int, 2>> *grid_size_overrides;

GridAnalysisInput() : max_supported_rows(1), max_supported_columns(1) {}
GridAnalysisInput()
: max_supported_rows(1), max_supported_columns(1),
grid_size_overrides(nullptr) {}

GridAnalysisInput(int max_supported_rows, int max_supported_columns)
GridAnalysisInput(int max_supported_rows, int max_supported_columns,
llvm::StringMap<SmallVector<int, 2>> *grid_size_overrides)
: max_supported_rows(max_supported_rows),
max_supported_columns(max_supported_columns) {}
max_supported_columns(max_supported_columns),
grid_size_overrides(grid_size_overrides) {}

bool operator==(const GridAnalysisInput &rhs) const {
return max_supported_rows == rhs.max_supported_rows &&
max_supported_columns == rhs.max_supported_columns;
max_supported_columns == rhs.max_supported_columns &&
grid_size_overrides == rhs.grid_size_overrides;
}

bool operator!=(const GridAnalysisInput &rhs) const {
Expand All @@ -41,6 +48,7 @@ class GridAnalysis

private:
void analysisImplementation() override;
bool applyOverrides() override;

public:
GridAnalysis(Operation *op) : TTIRAnalysis(op) {}
Expand Down
15 changes: 14 additions & 1 deletion include/ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ template <class I, class R> class TTIRAnalysis {
//
virtual void analysisImplementation() = 0;

// Load overrides if they exist.
// Must be implemented by every analysis type.
// Returns true if analysis should be skipped.
//
virtual bool applyOverrides() = 0;

public:
virtual ~TTIRAnalysis() {};

Expand Down Expand Up @@ -52,7 +58,14 @@ template <class I, class R> class TTIRAnalysis {
// Skip the analysis if it was already run and input params haven't changed.
//
if (!is_valid) {
analysisImplementation();
// Apply overrides if needed.
//
bool skip_analysis = applyOverrides();

if (!skip_analysis) {
analysisImplementation();
}

is_valid = true;
}
}
Expand Down
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTIR/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def TTIRGridSet: Pass<"ttir-grid-set", "::mlir::ModuleOp"> {
Go through the ops, set grid size for each op based on grid analysis,
by updating layout attribute of each op.
}];
let options = [
Option<"overrideGridSizes", "override-grid-sizes",
"llvm::StringMap<SmallVector<int, 2>>",
/*default=*/"llvm::StringMap<SmallVector<int, 2>>()",
"Override grid sizes for specific ops.">,
];
}

#endif
66 changes: 66 additions & 0 deletions include/ttmlir/Dialect/TTNN/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,58 @@ namespace mlir::tt::ttnn {
#define GEN_PASS_REGISTRATION
#include "ttmlir/Dialect/TTNN/Passes.h.inc"

// Command-line parser for per-op grid size overrides.
//
// Accepts a comma separated list of "opName=RxC" entries, e.g.
// "op1=2x2,op2=4x4", and produces a map from op name to a two-element
// {rows, columns} vector. Both `print` below and the grid analysis that
// consumes the map index elements [0] and [1] unconditionally, so the
// parser rejects any grid spec that does not have exactly two positive
// dimensions.
struct GridSizeOverrideParser
    : public llvm::cl::parser<llvm::StringMap<SmallVector<int, 2>>> {
public:
  GridSizeOverrideParser(llvm::cl::Option &opt)
      : llvm::cl::parser<llvm::StringMap<SmallVector<int, 2>>>(opt) {}

  // Parses `arg` into `value`. Returns true on error (llvm::cl convention),
  // false on success.
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             llvm::StringMap<SmallVector<int, 2>> &value) {
    SmallVector<StringRef> overrideList;
    constexpr size_t kvPairSize = 2;
    constexpr size_t kMaxGridSize = 2;
    constexpr size_t iOpName = 0;
    constexpr size_t iGrid = 1;
    arg.split(overrideList, ',');
    for (const StringRef overrideEntry : overrideList) {
      SmallVector<StringRef, kvPairSize> kv;
      overrideEntry.split(kv, '=');
      if (kv.size() != kvPairSize) {
        opt.error("Invalid format for override grid sizes: " + overrideEntry);
        return true;
      }
      SmallVector<int, kMaxGridSize> grid;
      SmallVector<StringRef, kMaxGridSize> gridParts;
      kv[iGrid].split(gridParts, 'x');
      // Require exactly RxC; anything else would later be read out of
      // bounds by consumers that assume two dimensions.
      if (gridParts.size() != kMaxGridSize) {
        opt.error("Invalid grid size (expected RxC): " + kv[iGrid]);
        return true;
      }
      for (const StringRef gridPart : gridParts) {
        int gridValue;
        if (gridPart.getAsInteger(10 /*Radix*/, gridValue) || gridValue <= 0) {
          opt.error("Invalid grid size: " + gridPart);
          return true;
        }
        grid.push_back(gridValue);
      }
      value[kv[iOpName]] = grid;
    }
    return false;
  }

  // Prints the map back in the same "opName=RxC,..." syntax it was parsed
  // from. Assumes every entry holds exactly two dimensions, which `parse`
  // guarantees.
  static void print(llvm::raw_ostream &os,
                    const llvm::StringMap<SmallVector<int, 2>> &value) {
    os << "override-grid-sizes=";
    size_t count = 0;
    for (const auto &entry : value) {
      os << entry.getKey() << "=";
      os << entry.getValue()[0] << "x" << entry.getValue()[1];
      if (++count < value.size()) {
        os << ",";
      }
    }
    os << "\n";
  }
};

// Options for the TTIR to TTNN backend pipeline.
//
struct TTIRToTTNNBackendPipelineOptions
Expand All @@ -30,6 +82,20 @@ struct TTIRToTTNNBackendPipelineOptions
*this, "enable-grid-set",
llvm::cl::desc("Determine and set max valid grid for Op execution."),
llvm::cl::init(true)};

// Option to override grid size for specific ops.
// The format is a comma separated list of op names and grid sizes.
//
// Example: "op1=2x2,op2=4x4"
//
// This will set the grid size for op1 to 2x2 and op2 to 4x4.
//
// Note: This option is only valid if gridSetPassEnabled is true.
//
Option<llvm::StringMap<SmallVector<int, 2>>, GridSizeOverrideParser>
overrideGridSizes{*this, "override-grid-sizes",
llvm::cl::desc("Override grid sizes for specific ops."),
llvm::cl::init(llvm::StringMap<SmallVector<int, 2>>())};
};

void createTTIRToTTNNBackendPipeline(
Expand Down
19 changes: 19 additions & 0 deletions lib/Dialect/TTIR/Analysis/GridAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,25 @@
#include "ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h"

namespace mlir::tt::ttir {

bool GridAnalysis::applyOverrides() {
// Lookup grid size overrides based on location information for current
// operation.
//
if (analysis_input.grid_size_overrides && op->getLoc().isa<NameLoc>()) {
StringRef loc_str_op_name = op->getLoc().cast<NameLoc>().getName();
auto grid_override =
analysis_input.grid_size_overrides->find(loc_str_op_name);
if (grid_override != analysis_input.grid_size_overrides->end()) {
analysis_result.target_rows = grid_override->second[0];
analysis_result.target_columns = grid_override->second[1];
return true;
}
}

return false;
}

void GridAnalysis::analysisImplementation() {
// Placeholder. For now result of analysis is maximum supported grid size.
//
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,8 +748,9 @@ class TTIRGridSet : public impl::TTIRGridSetBase<TTIRGridSet> {

// Initialize the grid analysis with the max grid size.
//
grid_analysis.init(
GridAnalysisInput(max_grid.getShape()[0], max_grid.getShape()[1]));
grid_analysis.init(GridAnalysisInput(max_grid.getShape()[0],
max_grid.getShape()[1],
&overrideGridSizes));

// Run the grid analysis and get the result.
//
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/TTNN/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,11 @@ class ConvertTTIRToTTNN
void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRLayout());

if (options.gridSetPassEnabled) {
pm.addPass(mlir::tt::ttir::createTTIRGridSet());
ttir::TTIRGridSetOptions gridSetOptions;
gridSetOptions.overrideGridSizes = options.overrideGridSizes;
pm.addPass(mlir::tt::ttir::createTTIRGridSet(gridSetOptions));
}

pm.addPass(createTTNNOpenDevice());
Expand Down
28 changes: 28 additions & 0 deletions test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="override-grid-sizes=add_1_0=4x4,add_2_0=4x4" %s | FileCheck %s
// Verifies per-op grid overrides: the two adds whose NameLocs are "add_1_0"
// and "add_2_0" must be placed on a 4x4 grid (#layout2), while "add_0" has no
// override and keeps the default 8x8 grid (#layout3).
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("test_ops.py:17_0_0":0:0)
module @pybuda_graph attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
  func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
    // CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
    // CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
    // CHECK: #layout3 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
    %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
    // Overridden via "add_1_0=4x4" -> 4x4 grid (#layout2).
    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
    %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
    %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
    // Overridden via "add_2_0=4x4" -> 4x4 grid (#layout2).
    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
    %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
    %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
    // No override for "add_0" -> default analysis grid 8x8 (#layout3).
    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout3>
    %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
    // CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
    return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("test_ops.py:17_0_0":0:4)
#loc2 = loc("test_ops.py:17_0_0":0:6)
#loc3 = loc("test_ops.py:17_0_0":0:3)
#loc4 = loc(unknown)
// NameLocs below are what the override flag keys against (see RUN line).
#loc5 = loc("add_1_0"(#loc1))
#loc6 = loc("add_2_0"(#loc2))
#loc7 = loc("add_0"(#loc3))

0 comments on commit f39d4d2

Please sign in to comment.