[Optimizer] Simple shard layout picker #1150

Merged · 1 commit · Nov 5, 2024
13 changes: 13 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -48,6 +48,19 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
static GridAttr get(::mlir::MLIRContext *context, std::int64_t rank) {
return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
}

uint64_t mutable cNumUsedCores = 0;

Contributor:
mutable is used for const attributes; does TableGen create this variable as const?


@nobradovictt (Contributor, Author), Nov 5, 2024:
The util getter I added is a const function, and this member caches the getter's result. I might have gone overboard on this one, but I didn't want to repeat the multiplication every time getNumUsedCores is invoked.


Contributor:
Got it. I'd prefer to call this gridVolume.


@nobradovictt (Contributor, Author):
I was in auto-merge mode and didn't stop it; I can rename it in my next change.

uint64_t getNumUsedCores() const {
if (cNumUsedCores != 0) {
return cNumUsedCores;
}

cNumUsedCores = 1;
for (int64_t dim : getShape()) {
cNumUsedCores *= dim;
}
return cNumUsedCores;
}
}];
}
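A minimal standalone sketch of the caching pattern discussed in the thread above, assuming (as the author describes) that the member is only a memoization slot for the const getter; the class name and layout here are illustrative, not the TableGen-generated attribute class:

#include <cstdint>
#include <utility>
#include <vector>

// Illustration only: mutable lets a const member function write the cache,
// so the grid volume is multiplied out at most once per instance.
class GridVolumeCacheSketch {
public:
  explicit GridVolumeCacheSketch(std::vector<int64_t> shape)
      : shape(std::move(shape)) {}

  uint64_t getNumUsedCores() const {
    if (cNumUsedCores != 0) {
      return cNumUsedCores; // cached result from an earlier call
    }
    cNumUsedCores = 1;
    for (int64_t dim : shape) {
      cNumUsedCores *= dim; // grid volume = product of grid dimensions
    }
    return cNumUsedCores;
  }

private:
  std::vector<int64_t> shape;
  mutable uint64_t cNumUsedCores = 0; // 0 means "not computed yet"
};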

2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
@@ -17,6 +17,8 @@ namespace mlir::tt::ttnn {
class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
private:
std::unordered_set<Edge> overrideReshardEdges;
void pickOpShardLayouts(ShardSolver &shardSolver,
const L1ChainConfig &l1ChainConfig);

public:
DFShardingPolicy(
36 changes: 28 additions & 8 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -192,19 +192,39 @@ void DFShardingPolicy::run() {
ShardSolver shardSolver = l1ChainConfig.resolveWithSolver(
legalLayouts, usableL1CacheSize, overrideReshardEdges);

// TODO(nobradovic)
// For now dummy fetch first legal(largest grid) for shard spec.
//
for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
Operation *op = shardSpec.op;
auto validLayouts = shardSolver.at(op);
shardSolver.set(op, *validLayouts.begin());
}
pickOpShardLayouts(shardSolver, l1ChainConfig);

ShardSolverSolution resolvedShardSolution = shardSolver.finish();
l1ChainConfig.complete(resolvedShardSolution.selectedOpLayout,
resolvedShardSolution.memReconfigEdges);
}
}

void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
const L1ChainConfig &l1ChainConfig) {
// TODO(nobradovic)
// Simple picker for now, choose the highest grid size for each op, prefer
// width and height sharding over block sharding.
//
for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
Operation *op = shardSpec.op;
ShardSolver::RemainingLayoutAttrs validLayouts = shardSolver.at(op);
const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
for (const tt::LayoutAttr &layout : validLayouts) {

if (layout.getGrid().getNumUsedCores() >
selectedLayout->getGrid().getNumUsedCores()) {
selectedLayout = &layout;
} else if (layout.getGrid().getNumUsedCores() ==
selectedLayout->getGrid().getNumUsedCores()) {
if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
selectedLayout = &layout;
}
}
}

shardSolver.set(op, *selectedLayout);
}
}

} // namespace mlir::tt::ttnn
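For illustration, the same selection rule as pickOpShardLayouts in a self-contained form; Candidate and MemLayout are hypothetical stand-ins for tt::LayoutAttr and tt::TensorMemoryLayout, not the real API:

#include <cstdint>
#include <vector>

enum class MemLayout { BlockSharded, HeightSharded, WidthSharded };

struct Candidate {        // hypothetical stand-in for tt::LayoutAttr
  uint64_t numUsedCores;  // grid volume, as getNumUsedCores() would return
  MemLayout memLayout;
};

// Largest grid wins; on an equal core count, any non-block-sharded candidate
// replaces the current pick. Assumes candidates is non-empty.
const Candidate *pick(const std::vector<Candidate> &candidates) {
  const Candidate *selected = &candidates.front();
  for (const Candidate &c : candidates) {
    if (c.numUsedCores > selected->numUsedCores) {
      selected = &c;
    } else if (c.numUsedCores == selected->numUsedCores &&
               c.memLayout != MemLayout::BlockSharded) {
      selected = &c;
    }
  }
  return selected;
}

With this rule, an 8x8 block-sharded candidate and a 1x64 width-sharded candidate both cover 64 cores, so the width-sharded one wins, which matches the CHECK updates from block_sharded to width_sharded in the tests below.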
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
@@ -165,7 +165,7 @@ void LegalGridAnalysis::analysisImplementation() {
// Height Sharded
// TODO(odjuricic): Missing affine mapping to actual grid. Need to check with
// runtime implementation on what to produce here.
for (auto height = 2; height <= numCores; ++height) {
for (auto height = 1; height <= numCores; ++height) {
shardedResults.push_back(
shardedBase
.withGrid(op->getContext(), tensorType,
@@ -175,7 +175,7 @@
}

// Width Sharded
for (auto width = 2; width <= numCores; ++width) {
for (auto width = 1; width <= numCores; ++width) {
shardedResults.push_back(
shardedBase
.withGrid(op->getContext(), tensorType,
@@ -196,8 +196,8 @@
// Pick top largest sharded grids.
std::sort(shardedResults.begin(), shardedResults.end(),
[](tt::LayoutAttr a, tt::LayoutAttr b) {
return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] >
b.getGrid().getShape()[0] * b.getGrid().getShape()[1];
return a.getGrid().getNumUsedCores() >
b.getGrid().getNumUsedCores();
});

analysisResult.insert(
12 changes: 11 additions & 1 deletion lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -543,7 +543,17 @@ bool ShardSolver::checkShardCompatible(
bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
tensorL1UsageCap * usableL1CacheSize;

return l1UsageValid;
if (!l1UsageValid) {
return false;
}

// Shard compat assumption. Try to keep same shard layout.
//
if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
return false;
}

return true;
}

// Returns ShardSolverSolution.
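A simplified sketch of the rule added to checkShardCompatible above (hypothetical signature, not the real one): beyond the existing L1 usage check, the producer and consumer must now use the same tensor memory layout.

enum class MemLayout { BlockSharded, HeightSharded, WidthSharded };

// Simplified model: both conditions must hold for the edge to be kept.
bool shardCompatibleSketch(bool l1UsageValid, MemLayout producer,
                           MemLayout consumer) {
  return l1UsageValid && producer == consumer;
}

For example, a width-sharded producer feeding a block-sharded consumer is now rejected even when both outputs fit within the L1 usage cap.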
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/mnist_sharding.mlir
@@ -3,8 +3,8 @@
#loc = loc("MNISTLinear":4294967295:0)
module @"tt-forge-graph" attributes {} {
func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
// CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded>
// CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded>
// CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, width_sharded>
// CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, width_sharded>
%0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
// CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]>
%1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
@@ -4,7 +4,7 @@
module attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> {
// CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
// CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
// CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
// CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
@@ -8,7 +8,7 @@
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func @main(%arg0: tensor<1x32x32xf32, #layout>, %arg1: tensor<1x32x32xf32, #layout>, %arg2: tensor<1x32x32xf32, #layout>) -> tensor<1x32x32xf32, #layout> {
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
// CHECK: #[[LAYOUT_3:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<32x32>>>}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
@@ -1,4 +1,4 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>