Commit

[Optimizer] Simple shard layout picker. Simple op input-output shard match constraint.
nobradovictt committed Nov 5, 2024
1 parent c038025 commit b452c61
Showing 9 changed files with 63 additions and 18 deletions.
13 changes: 13 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -48,6 +48,19 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
     static GridAttr get(::mlir::MLIRContext *context, std::int64_t rank) {
       return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
     }
+
+    uint64_t mutable cNumUsedCores = 0;
+    uint64_t getNumUsedCores() const {
+      if (cNumUsedCores != 0) {
+        return cNumUsedCores;
+      }
+
+      cNumUsedCores = 1;
+      for (int64_t dim : getShape()) {
+        cNumUsedCores *= dim;
+      }
+      return cNumUsedCores;
+    }
   }];
 }
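
Note: the new getNumUsedCores() helper memoizes the product of the grid dimensions in a mutable member, so the const getter pays for the multiplication only once. A minimal standalone sketch of the same caching pattern, with illustrative names that are not part of the commit:

#include <cstdint>
#include <vector>

// Illustrative stand-in for GridAttr's cached core count (hypothetical type).
struct GridShape {
  std::vector<int64_t> dims;
  mutable uint64_t cachedNumCores = 0; // 0 means "not computed yet"

  uint64_t numUsedCores() const {
    if (cachedNumCores != 0) {
      return cachedNumCores; // reuse the cached product
    }
    cachedNumCores = 1;
    for (int64_t d : dims) {
      cachedNumCores *= d; // e.g. a <2x4> grid reports 8 used cores
    }
    return cachedNumCores;
  }
};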

2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
@@ -17,6 +17,8 @@ namespace mlir::tt::ttnn {
 class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
 private:
   std::unordered_set<Edge> overrideReshardEdges;
+  void pickOpShardLayouts(ShardSolver &shardSolver,
+                          const L1ChainConfig &l1ChainConfig);
 
 public:
   DFShardingPolicy(
36 changes: 28 additions & 8 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -192,19 +192,39 @@ void DFShardingPolicy::run() {
     ShardSolver shardSolver = l1ChainConfig.resolveWithSolver(
         legalLayouts, usableL1CacheSize, overrideReshardEdges);
 
-    // TODO(nobradovic)
-    // For now dummy fetch first legal(largest grid) for shard spec.
-    //
-    for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
-      Operation *op = shardSpec.op;
-      auto validLayouts = shardSolver.at(op);
-      shardSolver.set(op, *validLayouts.begin());
-    }
+    pickOpShardLayouts(shardSolver, l1ChainConfig);
 
     ShardSolverSolution resolvedShardSolution = shardSolver.finish();
     l1ChainConfig.complete(resolvedShardSolution.selectedOpLayout,
                            resolvedShardSolution.memReconfigEdges);
   }
 }
 
+void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
+                                          const L1ChainConfig &l1ChainConfig) {
+  // TODO(nobradovic)
+  // Simple picker for now, choose the highest grid size for each op, prefer
+  // width and height sharding over block sharding.
+  //
+  for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
+    Operation *op = shardSpec.op;
+    ShardSolver::RemainingLayoutAttrs validLayouts = shardSolver.at(op);
+    const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
+    for (const tt::LayoutAttr &layout : validLayouts) {
+      if (layout.getGrid().getNumUsedCores() >
+          selectedLayout->getGrid().getNumUsedCores()) {
+        selectedLayout = &layout;
+      } else if (layout.getGrid().getNumUsedCores() ==
+                 selectedLayout->getGrid().getNumUsedCores()) {
+        if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
+          selectedLayout = &layout;
+        }
+      }
+    }
+
+    shardSolver.set(op, *selectedLayout);
+  }
+}
+
 } // namespace mlir::tt::ttnn
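
Note: the selection loop in pickOpShardLayouts amounts to taking the maximum layout under a simple ordering: more used cores wins, and on equal core counts a width- or height-sharded layout is preferred over a block-sharded one. A restatement of that ordering as a standalone comparator (hypothetical helper, assuming the ttmlir headers are available; not part of the commit):

// Hypothetical comparator mirroring the picker's preference order.
bool preferCandidate(const tt::LayoutAttr &candidate,
                     const tt::LayoutAttr &current) {
  uint64_t candidateCores = candidate.getGrid().getNumUsedCores();
  uint64_t currentCores = current.getGrid().getNumUsedCores();
  if (candidateCores != currentCores) {
    return candidateCores > currentCores; // larger grid wins outright
  }
  // Tie-break: prefer width/height sharding over block sharding.
  return candidate.getMemLayout() != tt::TensorMemoryLayout::BlockSharded;
}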
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
@@ -165,7 +165,7 @@ void LegalGridAnalysis::analysisImplementation() {
   // Height Sharded
   // TODO(odjuricic): Missing affine mapping to actual grid. Need to check with
   // runtime implementation on what to produce here.
-  for (auto height = 2; height <= numCores; ++height) {
+  for (auto height = 1; height <= numCores; ++height) {
     shardedResults.push_back(
         shardedBase
             .withGrid(op->getContext(), tensorType,
@@ -175,7 +175,7 @@
   }
 
   // Width Sharded
-  for (auto width = 2; width <= numCores; ++width) {
+  for (auto width = 1; width <= numCores; ++width) {
     shardedResults.push_back(
         shardedBase
             .withGrid(op->getContext(), tensorType,
@@ -196,8 +196,8 @@
   // Pick top largest sharded grids.
   std::sort(shardedResults.begin(), shardedResults.end(),
             [](tt::LayoutAttr a, tt::LayoutAttr b) {
-              return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] >
-                     b.getGrid().getShape()[0] * b.getGrid().getShape()[1];
+              return a.getGrid().getNumUsedCores() >
+                     b.getGrid().getNumUsedCores();
             });
 
   analysisResult.insert(
12 changes: 11 additions & 1 deletion lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -543,7 +543,17 @@ bool ShardSolver::checkShardCompatible(
   bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
                       tensorL1UsageCap * usableL1CacheSize;
 
-  return l1UsageValid;
+  if (!l1UsageValid) {
+    return false;
+  }
+
+  // Shard compat assumption. Try to keep same shard layout.
+  //
+  if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
+    return false;
+  }
+
+  return true;
 }
 
 // Returns ShardSolverSolution.
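
Note: checkShardCompatible now enforces two conditions on a producer/consumer pair: their combined L1 output usage must stay under the capped L1 budget, and both layouts must use the same tensor memory layout (the "op input-output shard match" constraint from the commit message). A condensed, hypothetical restatement of the predicate (not part of the commit; parameter list simplified):

// Hypothetical condensed form of the compatibility check.
bool shardsCompatible(uint64_t producerL1OutputUsage,
                      uint64_t consumerL1OutputUsage, float tensorL1UsageCap,
                      uint64_t usableL1CacheSize,
                      tt::TensorMemoryLayout producerMemLayout,
                      tt::TensorMemoryLayout consumerMemLayout) {
  bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
                      tensorL1UsageCap * usableL1CacheSize;
  // Keep the same shard layout across the producer->consumer edge.
  return l1UsageValid && producerMemLayout == consumerMemLayout;
}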
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/mnist_sharding.mlir
@@ -3,8 +3,8 @@
 #loc = loc("MNISTLinear":4294967295:0)
 module @"tt-forge-graph" attributes {} {
   func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
-    // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded>
-    // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded>
+    // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, width_sharded>
+    // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, width_sharded>
     %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
     // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]>
     %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
@@ -4,7 +4,7 @@
 module attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> {
     // CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
-    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
     // CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
     // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
@@ -8,7 +8,7 @@
 module attributes {tt.device = #device, tt.system_desc = #system_desc} {
   func.func @main(%arg0: tensor<1x32x32xf32, #layout>, %arg1: tensor<1x32x32xf32, #layout>, %arg2: tensor<1x32x32xf32, #layout>) -> tensor<1x32x32xf32, #layout> {
     // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
-    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
     // CHECK: #[[LAYOUT_3:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
     %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
     %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<32x32>>>}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
