diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
index 04f2b64af5..3b70ee063f 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
+++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -48,6 +48,19 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
     static GridAttr get(::mlir::MLIRContext *context, std::int64_t rank) {
       return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
     }
+
+    uint64_t mutable cNumUsedCores = 0;
+    uint64_t getNumUsedCores() const {
+      if (cNumUsedCores != 0) {
+        return cNumUsedCores;
+      }
+
+      cNumUsedCores = 1;
+      for (int64_t dim : getShape()) {
+        cNumUsedCores *= dim;
+      }
+      return cNumUsedCores;
+    }
   }];
 }
 
diff --git a/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
index 6ef8476b00..38edc5ea37 100644
--- a/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
+++ b/include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
@@ -20,6 +20,8 @@ class DFShardingPolicy {
   llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
   llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
   unsigned usableL1CacheSize = 0;
+  void pickOpShardLayouts(ShardSolver &shardSolver,
+                          const L1ChainConfig &l1ChainConfig);
 
 public:
   DFShardingPolicy(
diff --git a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
index 7a7470ad39..e2cdba9226 100644
--- a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -193,14 +193,7 @@ void DFShardingPolicy::run(
       ShardSolver shardSolver = l1ChainConfig.resolveWithSolver(
           legalLayouts, usableL1CacheSize, overrideReshardEdges);
 
-      // TODO(nobradovic)
-      // For now dummy fetch first legal(largest grid) for shard spec.
-      //
-      for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
-        Operation *op = shardSpec.op;
-        auto validLayouts = shardSolver.at(op);
-        shardSolver.set(op, *validLayouts.begin());
-      }
+      pickOpShardLayouts(shardSolver, l1ChainConfig);
 
       ShardSolverSolution resolvedShardSolution = shardSolver.finish();
       l1ChainConfig.complete(resolvedShardSolution.selectedOpLayout,
@@ -208,4 +201,31 @@
   }
 }
 
+void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
+                                          const L1ChainConfig &l1ChainConfig) {
+  // TODO(nobradovic)
+  // Simple picker for now: choose the layout with the most used cores for each
+  // op, and prefer width and height sharding over block sharding on ties.
+  //
+  for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
+    Operation *op = shardSpec.op;
+    ShardSolver::RemainingLayoutAttrs validLayouts = shardSolver.at(op);
+    const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
+    for (const tt::LayoutAttr &layout : validLayouts) {
+
+      if (layout.getGrid().getNumUsedCores() >
+          selectedLayout->getGrid().getNumUsedCores()) {
+        selectedLayout = &layout;
+      } else if (layout.getGrid().getNumUsedCores() ==
+                 selectedLayout->getGrid().getNumUsedCores()) {
+        if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
+          selectedLayout = &layout;
+        }
+      }
+    }
+
+    shardSolver.set(op, *selectedLayout);
+  }
+}
+
 } // namespace mlir::tt::ttnn
diff --git a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
index 512006c387..46287e72e2 100644
--- a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
+++ b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
@@ -165,7 +165,7 @@ void LegalGridAnalysis::analysisImplementation() {
   // Height Sharded
   // TODO(odjuricic): Missing affine mapping to actual grid. Need to check with
   // runtime implementation on what to produce here.
-  for (auto height = 2; height <= numCores; ++height) {
+  for (auto height = 1; height <= numCores; ++height) {
    shardedResults.push_back(
        shardedBase
            .withGrid(op->getContext(), tensorType,
@@ -175,7 +175,7 @@
  }
 
   // Width Sharded
-  for (auto width = 2; width <= numCores; ++width) {
+  for (auto width = 1; width <= numCores; ++width) {
    shardedResults.push_back(
        shardedBase
            .withGrid(op->getContext(), tensorType,
@@ -196,8 +196,8 @@
   // Pick top largest sharded grids.
   std::sort(shardedResults.begin(), shardedResults.end(),
             [](tt::LayoutAttr a, tt::LayoutAttr b) {
-              return a.getGrid().getShape()[0] * a.getGrid().getShape()[1] >
-                     b.getGrid().getShape()[0] * b.getGrid().getShape()[1];
+              return a.getGrid().getNumUsedCores() >
+                     b.getGrid().getNumUsedCores();
            });
 
  analysisResult.insert(
diff --git a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
index 6cc7d1effd..982ee8d50f 100644
--- a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
+++ b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -535,7 +535,17 @@ bool ShardSolver::checkShardCompatible(
   bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
                       tensorL1UsageCap * usableL1CacheSize;
 
-  return l1UsageValid;
+  if (!l1UsageValid) {
+    return false;
+  }
+
+  // Shard compatibility assumption: try to keep the same tensor memory layout
+  // across the producer-consumer edge.
+  if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
+    return false;
+  }
+
+  return true;
 }
 
 // Returns ShardSolverSolution.
diff --git a/test/ttmlir/Dialect/TTNN/mnist_sharding.mlir b/test/ttmlir/Dialect/TTNN/mnist_sharding.mlir
index 04cde9d747..f608341c84 100644
--- a/test/ttmlir/Dialect/TTNN/mnist_sharding.mlir
+++ b/test/ttmlir/Dialect/TTNN/mnist_sharding.mlir
@@ -3,8 +3,8 @@
 #loc = loc("MNISTLinear":4294967295:0)
 module @"tt-forge-graph" attributes {} {
   func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
-    // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded>
-    // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded>
+    // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, width_sharded>
+    // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, width_sharded>
     %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
     // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]>
     %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir
index fb2eaa465c..b1951166a0 100644
--- a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_input_layout_override.mlir
@@ -4,7 +4,7 @@
 module attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> {
     // CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
-    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
     // CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
     // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
diff --git a/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir b/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
index 34eb9bdc56..07868a8c56 100644
--- a/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
+++ b/test/ttmlir/Dialect/TTNN/test_override_reshard_edges.mlir
@@ -8,7 +8,7 @@
 module attributes {tt.device = #device, tt.system_desc = #system_desc} {
   func.func @main(%arg0: tensor<1x32x32xf32, #layout>, %arg1: tensor<1x32x32xf32, #layout>, %arg2: tensor<1x32x32xf32, #layout>) -> tensor<1x32x32xf32, #layout> {
     // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
-    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, width_sharded>
     // CHECK: #[[LAYOUT_3:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>, interleaved>
     %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
     %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<<interleaved>, <dram>, <<32x32>>>}> : (tensor<1x32x32xf32, #layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #layout1>
diff --git a/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir b/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
index bca09ff34f..9fc73e0660 100644
--- a/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
+++ b/test/ttmlir/Silicon/TTNN/sharded/mnist_sharding_tiled.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true override-output-layout=matmul_1=1x8:l1:width_sharded,add_2=1x8:l1:width_sharded,relu_3=1x8:l1:width_sharded,matmul_5=1x1:l1:width_sharded,add_6=1x1:l1:width_sharded,softmax_7=1x1:l1:width_sharded" %s > %t.mlir
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 #any_device = #tt.operand_constraint<any_device>