diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
index 3783e94c1a..8ee305ed44 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
+++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -49,17 +49,17 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
       return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
     }
 
-    uint64_t mutable cNumUsedCores = 0;
-    uint64_t getNumUsedCores() const {
-      if (cNumUsedCores != 0) {
-        return cNumUsedCores;
+    uint64_t mutable cGridVolume = 0;
+    uint64_t getGridVolume() const {
+      if (cGridVolume != 0) {
+        return cGridVolume;
       }
-      cNumUsedCores = 1;
+      cGridVolume = 1;
       for (int64_t dim : getShape()) {
-        cNumUsedCores *= dim;
+        cGridVolume *= dim;
       }
-      return cNumUsedCores;
+      return cGridVolume;
     }
   }];
 }
 
diff --git a/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h b/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
index e1a8a6d2bd..3c281679d5 100644
--- a/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
+++ b/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
@@ -289,6 +289,7 @@ class ShardSolver {
       const std::unordered_set<Edge> &overrideReshardEdges);
   RemainingLayoutAttrs at(Operation *operation) const;
   void set(Operation *operation, tt::LayoutAttr const &layout);
+  static bool supportsInterleavedInputShardedOutput(Operation *op);
 
 private:
   const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> *legalLayouts;
diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
index a42ec0ea88..d808613b90 100644
--- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
+++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -78,10 +78,8 @@ struct TTIRToTTNNBackendPipelineOptions
   //
   Option<bool> memReconfigEnabled{
      *this, "memreconfig-enabled",
-      llvm::cl::desc("Memory layout reconfiguration pass. Temp disabled till "
-                     "we support all types "
-                     "of shard specs."),
-      llvm::cl::init(false)};
+      llvm::cl::desc("Memory layout reconfiguration pass."),
+      llvm::cl::init(true)};
 
   // Specify policy for memory layout analysis.
   //
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h b/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
index 06074a0a34..e7d9c459f8 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -121,10 +121,8 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
       ::llvm::cl::init(false)};
   ::mlir::Pass::Option<bool> memReconfigEnabled{
       *this, "memreconfig-enabled",
-      ::llvm::cl::desc("Memory layout reconfiguration pass. Temp disabled till "
-                       "we support all "
-                       "types of shard specs."),
-      ::llvm::cl::init(false)};
+      ::llvm::cl::desc("Memory layout reconfiguration pass."),
+      ::llvm::cl::init(true)};
   ::mlir::Pass::Option memoryLayoutAnalysisPolicy{
diff --git a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
index ebb928cd93..1f2e23de23 100644
--- a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -115,15 +115,19 @@ void DFShardingPolicy::run() {
 
           if (l1UsageValid) {
             // TODO(nobradovic)
-            // It seems that bunch of TTNN ops have constraints which prevent
+            // It seems that some TTNN ops have constraints which prevent
             // them from being sharded if both inputs are interleaved,
             // so proposal for now is starting a shard chain
-            // with reshard op(at later phase only when necessary based on op
-            // type) For this reason we also need to validate that currentOp
-            // can fit into L1 with its first input sharded.
+            // with reshard op. For this reason we also need to validate that
+            // currentOp can fit into L1 with its first input sharded.
             //
             bool firstInputL1UsageValid = true;
-            if (l1ChainConfigs->back().isEmpty()) {
+            if (l1ChainConfigs->back().isEmpty() &&
+                (!ShardSolver::supportsInterleavedInputShardedOutput(
+                     currentOp) ||
+                 overrideReshardEdges.count(
+                     Edge(currentOp->getOperand(0).getDefiningOp(), currentOp,
+                          0)) > 0)) {
              RankedTensorType firstOpInputTensorType =
                  mlir::cast<RankedTensorType>(currentOp->getOperand(0)
                                                   .getDefiningOp()
@@ -212,11 +216,11 @@ void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
       const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
       for (const tt::LayoutAttr &layout : validLayouts) {
 
-        if (layout.getGrid().getNumUsedCores() >
-            selectedLayout->getGrid().getNumUsedCores()) {
+        if (layout.getGrid().getGridVolume() >
+            selectedLayout->getGrid().getGridVolume()) {
           selectedLayout = &layout;
-        } else if (layout.getGrid().getNumUsedCores() ==
-                   selectedLayout->getGrid().getNumUsedCores()) {
+        } else if (layout.getGrid().getGridVolume() ==
+                   selectedLayout->getGrid().getGridVolume()) {
           if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
             selectedLayout = &layout;
           }
diff --git a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
index 46287e72e2..25997d2b83 100644
--- a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
+++ b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
@@ -196,8 +196,7 @@ void LegalGridAnalysis::analysisImplementation() {
   // Pick top largest sharded grids.
   std::sort(shardedResults.begin(), shardedResults.end(),
             [](tt::LayoutAttr a, tt::LayoutAttr b) {
-              return a.getGrid().getNumUsedCores() >
-                     b.getGrid().getNumUsedCores();
+              return a.getGrid().getGridVolume() > b.getGrid().getGridVolume();
             });
 
   analysisResult.insert(
diff --git a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
index 7108161192..3893b9d9b1 100644
--- a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
+++ b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -21,7 +21,8 @@ ShardSolver::ShardSolver(
     const unsigned usableL1CacheSize,
     const std::unordered_set<Edge> &overrideReshardEdges)
     : legalLayouts(&legalLayouts), shardSpecs(&shardSpecs),
-      shardedOps(&shardedOps), usableL1CacheSize(usableL1CacheSize) {
+      shardedOps(&shardedOps), usableL1CacheSize(usableL1CacheSize),
+      memReconfigEdges(overrideReshardEdges) {
   pathSets.reserve(shardSpecs.size());
   pathSetIds.reserve(shardSpecs.size());
   bitsets.reserve(shardedOps.size());
@@ -46,12 +47,6 @@ ShardSolver::ShardSolver(
     }
   }
 
-  // Insert override resharding edges
-  //
-  for (const Edge &edge : overrideReshardEdges) {
-    insertReshard(edge);
-  }
-
   // Resolve shard chain.
   //
   resolve();
@@ -181,17 +176,27 @@ bool ShardSolver::resolveStep() {
   return true;
 }
 
+bool ShardSolver::supportsInterleavedInputShardedOutput(Operation *op) {
+  // TODO(nobradovic,mbezulj): Add check whether this op type can have sharded
+  // output from interleaved inputs. For now assuming it can.
+  //
+  return true;
+}
+
 // We need to check if first op requires sharded inputs and if so, insert
 // reshard edge, then invalidate all sharding options which would go above L1
 // size limits.
 //
 void ShardSolver::preprocessFirstOp() {
-  // TODO(nobradovic): Add check whether this op type can have sharded output
-  // from interleaved inputs. For now assuming it can not.
-  //
+  Operation *firstOp = shardSpecs->front().op;
+  if (supportsInterleavedInputShardedOutput(firstOp) &&
+      memReconfigEdges.count(
+          Edge(firstOp->getOperand(0).getDefiningOp(), firstOp, 0)) == 0) {
+    return;
+  }
+
   // Insert reshard edge for the first op to start the chain.
   //
-  Operation *firstOp = shardSpecs->front().op;
   Edge shardChainInputEdge =
       Edge(firstOp->getOperand(0).getDefiningOp(), firstOp, 0 /*operandIndex*/);
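
For reference, a minimal standalone sketch of the two behaviors this patch touches: the reshard-at-chain-start decision now gated by `supportsInterleavedInputShardedOutput` plus the override edges, and the grid-volume product behind the `getNumUsedCores()` -> `getGridVolume()` rename. All names below (`OpId`, `EdgeKey`, `needsReshardAtChainStart`, `gridVolume`) are illustrative stand-ins, not part of the patch or the MLIR APIs it uses.

```cpp
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// Illustrative stand-in types; the real code operates on MLIR Operation*
// and the ShardSolver::Edge struct shown in the diff above.
using OpId = int;
using EdgeKey = std::pair<OpId, OpId>; // (producer, consumer) on operand 0

// Sketch of the patched preprocessFirstOp() decision: the first op of a
// shard chain only gets a reshard edge if it cannot produce a sharded
// output from interleaved inputs, or if that edge was explicitly overridden.
bool needsReshardAtChainStart(OpId producer, OpId firstOp,
                              bool supportsInterleavedInputShardedOutput,
                              const std::set<EdgeKey> &overrideReshardEdges) {
  if (supportsInterleavedInputShardedOutput &&
      overrideReshardEdges.count({producer, firstOp}) == 0) {
    return false; // start the chain directly, no reshard needed
  }
  return true; // insert a reshard edge to start the chain
}

// Grid volume as computed by the renamed GridAttr::getGridVolume(): the
// product of the grid's shape dimensions, i.e. the number of cores covered.
uint64_t gridVolume(const std::vector<int64_t> &gridShape) {
  uint64_t volume = 1;
  for (int64_t dim : gridShape) {
    volume *= static_cast<uint64_t>(dim);
  }
  return volume;
}
```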