Remove maxpool2d preshard workaround
LPanosTT committed Feb 6, 2025
1 parent b519b1f commit 588581e
Showing 3 changed files with 7 additions and 54 deletions.
15 changes: 4 additions & 11 deletions runtime/include/tt/runtime/detail/workarounds.h
@@ -15,7 +15,7 @@ struct Env {
#else
constexpr static Env
#endif
- get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true,
+ get(bool swapBinaryOperands = true,
bool readUpdateIndexFromDeviceForKVCache = true,
bool defaultStrideComputation = true,
bool toLayoutAPIAssumeSingleChip = true,
@@ -24,13 +24,9 @@ struct Env {
;
#else
{
- return Env(true, true, true, true, true, true);
+ return Env(true, true, true, true, true);
}
#endif
- // TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
- // instead of adding a method in runtime
- bool maxpool2dPreshard;
-
// TODO(bug #1124): We're currently swapping the operands for binary ops
// in runtime if the lhs operand is smaller (and requires broadcast onto the
// rhs operand). We should add this check in the compiler.
@@ -68,12 +64,11 @@ struct Env {
bool usePaddingPairSignatureWithQueueId;

private:
- constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands,
+ constexpr Env(bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache,
bool defaultStrideComputation, bool toLayoutAPIAssumeSingleChip,
bool usePaddingPairSignatureWithQueueId)
- : maxpool2dPreshard(maxpool2dPreshard),
-   swapBinaryOperands(swapBinaryOperands),
+ : swapBinaryOperands(swapBinaryOperands),
readUpdateIndexFromDeviceForKVCache(
readUpdateIndexFromDeviceForKVCache),
defaultStrideComputation(defaultStrideComputation),
@@ -84,8 +79,6 @@ struct Env {

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
os << "workaround::Env{\n";
os << "\t"
<< "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
os << "\t"
<< "swapBinaryOperands: " << env.swapBinaryOperands << ",\n";
os << "\t"
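As a side note on how these flags are consumed: runtime ops read them through the same Env::get() accessor that previously exposed maxpool2dPreshard. A minimal sketch of checking one of the surviving flags, using a hypothetical helper and omitting the operand-size check mentioned in the bug #1124 note (illustrative only, not code from this commit):

#include "tt/runtime/detail/workarounds.h"
#include <utility>

// Hypothetical helper: consult a surviving workaround flag before acting.
// A real caller would also confirm that the lhs operand is the smaller one
// (see the bug #1124 comment above); that check is omitted here for brevity.
template <typename TensorT>
void maybeSwapBinaryOperands(TensorT &lhs, TensorT &rhs) {
  if (tt::runtime::workaround::Env::get().swapBinaryOperands) {
    std::swap(lhs, rhs);
  }
}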
4 changes: 2 additions & 2 deletions runtime/lib/common/workarounds.cpp
@@ -6,12 +6,12 @@

namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
- const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands,
+ const Env &Env::get(bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache,
bool defaultStrideComputation,
bool toLayoutAPIAssumeSingleChip,
bool usePaddingPairSignatureWithQueueId) {
- static const Env config(maxpool2dPreshard, swapBinaryOperands,
+ static const Env config(swapBinaryOperands,
readUpdateIndexFromDeviceForKVCache,
defaultStrideComputation, toLayoutAPIAssumeSingleChip,
usePaddingPairSignatureWithQueueId);
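Worth noting: get() builds a function-local static, so whichever call reaches it first fixes the flag values for the rest of the process; later calls return the same Env and their arguments are ignored. A small sketch of that behaviour under a TT_RUNTIME_WORKAROUNDS=1 build (assumed usage, not part of this commit):

#include "tt/runtime/detail/workarounds.h"
#include <cassert>

void workaroundConfigIsPinnedByFirstCall() {
  using tt::runtime::workaround::Env;
  // Assumes nothing has called Env::get() earlier in the process.
  // The first call constructs the static config with swapBinaryOperands=false.
  const Env &first = Env::get(/*swapBinaryOperands=*/false);
  // A later call with different arguments returns the same object unchanged.
  const Env &later = Env::get(/*swapBinaryOperands=*/true);
  assert(&first == &later);
  assert(first.swapBinaryOperands == false);
}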
42 changes: 1 addition & 41 deletions runtime/lib/ttnn/operations/pool/maxpool2d.cpp
@@ -5,45 +5,13 @@
#include "operations/pool/maxpool2d.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/detail/workarounds.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/types.hpp"
#include <optional>

namespace tt::runtime::ttnn::operations::pool {

- // TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
- // instead of adding a method in runtime
- template <typename DeviceType>
- static ::ttnn::Tensor
- preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op,
-                      DeviceType &device, const ::ttnn::Tensor &input) {
-   const ::ttnn::Shape inputShape =
-       ::tt::runtime::ttnn::operations::utils::toTTNNShape(
-           *op->in()->desc()->shape());
-   uint32_t output_height =
-       1 + (op->input_height() + 2 * op->padding_height() -
-            op->dilation_height() * (op->kernel_height() - 1) - 1) /
-               op->stride_height();
-   uint32_t output_width =
-       1 + (op->input_width() + 2 * op->padding_width() -
-            op->dilation_width() * (op->kernel_width() - 1) - 1) /
-               op->stride_width();
-
-   constexpr bool en_ch_padding = false;
-
-   auto parallel_config = ::ttnn::operations::conv::determine_parallel_config(
-       ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(),
-       op->channels(), output_height, output_width, op->channels(),
-       device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR,
-       en_ch_padding);
-   auto sharded_memory_config = ::ttnn::operations::conv::
-       create_sharded_memory_config_from_parallel_config(inputShape,
-                                                         parallel_config, 1);
-   return ::ttnn::to_memory_config(input, sharded_memory_config, std::nullopt);
- }

void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::operations::pool::Pool2DOp<
@@ -53,15 +21,7 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) {

::ttnn::Tensor input = tensorPool.at(op->in()->global_id());
DEBUG_ASSERT(input.is_allocated());
-   if (workaround::Env::get().maxpool2dPreshard) {
-     DeviceVariant targetDevice =
-         context.getTargetDevice(op->device()->global_id());
-     input = std::visit(
-         [&](auto &&targetDevice) -> ::ttnn::Tensor {
-           return preshardForMaxPool2d(op, targetDevice.get(), input);
-         },
-         targetDevice);
-   }

::ttnn::MemoryConfig outMemConfig =
::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
::ttnn::Tensor out = operation.invoke(
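For reference, the deleted helper derived the pooled output extent with the usual formula, out = 1 + (in + 2*pad - dilation*(kernel - 1) - 1) / stride, before building a height-sharded memory config from it. A standalone restatement of that arithmetic (hypothetical helper, kept here only for illustration; it is no longer part of the runtime after this commit):

#include <cstdint>

// Same integer arithmetic the removed preshardForMaxPool2d helper used to
// derive output_height / output_width from the op attributes.
static uint32_t pooledExtent(uint32_t in, uint32_t kernel, uint32_t stride,
                             uint32_t padding, uint32_t dilation) {
  return 1 + (in + 2 * padding - dilation * (kernel - 1) - 1) / stride;
}

// Example: a 32x32 input with a 2x2 kernel, stride 2, no padding, dilation 1
// pools down to 16x16: pooledExtent(32, 2, 2, 0, 1) == 16.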
