From 588581e1124874784090c9070d30d7d7833701de Mon Sep 17 00:00:00 2001
From: Lewis Panos
Date: Thu, 6 Feb 2025 16:23:49 +0000
Subject: [PATCH] Remove maxpool2d preshard workaround

---
 .../include/tt/runtime/detail/workarounds.h | 15 ++-----
 runtime/lib/common/workarounds.cpp          |  4 +-
 .../lib/ttnn/operations/pool/maxpool2d.cpp  | 42 +------------------
 3 files changed, 7 insertions(+), 54 deletions(-)

diff --git a/runtime/include/tt/runtime/detail/workarounds.h b/runtime/include/tt/runtime/detail/workarounds.h
index c5fcca38fe..fd2850c065 100644
--- a/runtime/include/tt/runtime/detail/workarounds.h
+++ b/runtime/include/tt/runtime/detail/workarounds.h
@@ -15,7 +15,7 @@ struct Env {
 #else
   constexpr static Env
 #endif
-  get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true,
+  get(bool swapBinaryOperands = true,
       bool readUpdateIndexFromDeviceForKVCache = true,
       bool defaultStrideComputation = true,
       bool toLayoutAPIAssumeSingleChip = true,
@@ -24,13 +24,9 @@ struct Env {
       ;
 #else
   {
-    return Env(true, true, true, true, true, true);
+    return Env(true, true, true, true, true);
   }
 #endif
-  // TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
-  // instead of adding a method in runtime
-  bool maxpool2dPreshard;
-
   // TODO(bug #1124): We're currently swapping the operands for binary ops
   // in runtime if the lhs operand is smaller (and requires broadcast onto the
   // rhs operand). We should add this check in the compiler.
@@ -68,12 +64,11 @@ struct Env {
   bool usePaddingPairSignatureWithQueueId;

 private:
-  constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands,
+  constexpr Env(bool swapBinaryOperands,
                 bool readUpdateIndexFromDeviceForKVCache,
                 bool defaultStrideComputation, bool toLayoutAPIAssumeSingleChip,
                 bool usePaddingPairSignatureWithQueueId)
-      : maxpool2dPreshard(maxpool2dPreshard),
-        swapBinaryOperands(swapBinaryOperands),
+      : swapBinaryOperands(swapBinaryOperands),
         readUpdateIndexFromDeviceForKVCache(
             readUpdateIndexFromDeviceForKVCache),
         defaultStrideComputation(defaultStrideComputation),
@@ -84,8 +79,6 @@ struct Env {

 inline std::ostream &operator<<(std::ostream &os, const Env &env) {
   os << "workaround::Env{\n";
-  os << "\t"
-     << "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
   os << "\t"
      << "swapBinaryOperands: " << env.swapBinaryOperands << ",\n";
   os << "\t"
diff --git a/runtime/lib/common/workarounds.cpp b/runtime/lib/common/workarounds.cpp
index e9847eed24..453accd2bb 100644
--- a/runtime/lib/common/workarounds.cpp
+++ b/runtime/lib/common/workarounds.cpp
@@ -6,12 +6,12 @@

 namespace tt::runtime::workaround {
 #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
-const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands,
+const Env &Env::get(bool swapBinaryOperands,
                     bool readUpdateIndexFromDeviceForKVCache,
                     bool defaultStrideComputation,
                     bool toLayoutAPIAssumeSingleChip,
                     bool usePaddingPairSignatureWithQueueId) {
-  static const Env config(maxpool2dPreshard, swapBinaryOperands,
+  static const Env config(swapBinaryOperands,
                           readUpdateIndexFromDeviceForKVCache,
                           defaultStrideComputation, toLayoutAPIAssumeSingleChip,
                           usePaddingPairSignatureWithQueueId);
diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp
index b440c4024a..3eb3ef3ab3 100644
--- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp
+++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp
@@ -5,7 +5,6 @@
 #include "operations/pool/maxpool2d.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
-#include "tt/runtime/detail/workarounds.h"
 #include "tt/runtime/ttnn/operations/utils.h"
 #include "tt/runtime/ttnn/utils.h"
 #include "ttnn/types.hpp"
@@ -13,37 +12,6 @@

 namespace tt::runtime::ttnn::operations::pool {

-// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
-// instead of adding a method in runtime
-template <typename DeviceType>
-static ::ttnn::Tensor
-preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op,
-                     DeviceType &device, const ::ttnn::Tensor &input) {
-  const ::ttnn::Shape inputShape =
-      ::tt::runtime::ttnn::operations::utils::toTTNNShape(
-          *op->in()->desc()->shape());
-  uint32_t output_height =
-      1 + (op->input_height() + 2 * op->padding_height() -
-           op->dilation_height() * (op->kernel_height() - 1) - 1) /
-              op->stride_height();
-  uint32_t output_width =
-      1 + (op->input_width() + 2 * op->padding_width() -
-           op->dilation_width() * (op->kernel_width() - 1) - 1) /
-              op->stride_width();
-
-  constexpr bool en_ch_padding = false;
-
-  auto parallel_config = ::ttnn::operations::conv::determine_parallel_config(
-      ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(),
-      op->channels(), output_height, output_width, op->channels(),
-      device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR,
-      en_ch_padding);
-  auto sharded_memory_config = ::ttnn::operations::conv::
-      create_sharded_memory_config_from_parallel_config(inputShape,
-                                                        parallel_config, 1);
-  return ::ttnn::to_memory_config(input, sharded_memory_config, std::nullopt);
-}
-
 void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   const ::ttnn::operations::pool::Pool2DOp<
@@ -53,15 +21,7 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {
   ::ttnn::Tensor input = tensorPool.at(op->in()->global_id());
   DEBUG_ASSERT(input.is_allocated());

-  if (workaround::Env::get().maxpool2dPreshard) {
-    DeviceVariant targetDevice =
-        context.getTargetDevice(op->device()->global_id());
-    input = std::visit(
-        [&](auto &&targetDevice) -> ::ttnn::Tensor {
-          return preshardForMaxPool2d(op, targetDevice.get(), input);
-        },
-        targetDevice);
-  }
+
   ::ttnn::MemoryConfig outMemConfig =
       ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
   ::ttnn::Tensor out = operation.invoke(
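Note on the updated workaround API: with maxpool2dPreshard gone,
workaround::Env::get() takes five flags instead of six, and because the Env
is constructed as a function-local static, the first call latches the
configuration for the lifetime of the process. A minimal sketch of a
post-patch call site, assuming TT_RUNTIME_WORKAROUNDS == 1 (the flag values
shown are just the defaults, for illustration):

  #include "tt/runtime/detail/workarounds.h"

  // Latch the workaround configuration early, before any other component
  // calls Env::get(); subsequent calls return the same static Env and
  // ignore their arguments.
  const auto &env = tt::runtime::workaround::Env::get(
      /*swapBinaryOperands=*/true,
      /*readUpdateIndexFromDeviceForKVCache=*/true,
      /*defaultStrideComputation=*/true,
      /*toLayoutAPIAssumeSingleChip=*/true,
      /*usePaddingPairSignatureWithQueueId=*/true);

  if (env.swapBinaryOperands) {
    // Runtime-side operand swap for broadcasted binary ops (bug #1124).
  }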