
Update eltwise flatbuffer to include extra params table, move eltwise composite to its own file to match ttnn implementation (#1121)
jnie-TT authored Oct 31, 2024
1 parent 7560daf commit 4dc7502
Showing 19 changed files with 304 additions and 90 deletions.
5 changes: 5 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -88,10 +88,15 @@ enum EltwiseOpType: uint32 {
Cos = 27
}

union EltwiseOpParams {

}

table EltwiseOp {
type: EltwiseOpType;
ins: [tt.target.TensorRef];
out: tt.target.TensorRef;
params: EltwiseOpParams;
}

enum ReductionOpType: uint32 {
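The new EltwiseOpParams union is deliberately empty: it reserves a field in the EltwiseOp table so that per-op parameter tables can be added later without changing the table layout. Below is a minimal sketch of how generated C++ code could inspect that field; the FutureOpParams member named in the comment is hypothetical and does not exist in this commit, only the NONE discriminator is guaranteed today.

#include "ttmlir/Target/TTNN/program_generated.h"

// Sketch only. Every op currently serializes with EltwiseOpParams::NONE; once a
// concrete table (e.g. a hypothetical FutureOpParams) joins the union, flatc also
// generates a params_as_FutureOpParams() accessor next to params_type().
static bool opCarriesParams(const ::tt::target::ttnn::EltwiseOp *op) {
  return op->params_type() != ::tt::target::ttnn::EltwiseOpParams::NONE;
}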
6 changes: 5 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -294,6 +294,9 @@ template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
::tt::target::ttnn::EltwiseOpType type;
::tt::target::ttnn::EltwiseOpParams paramsType =
::tt::target::ttnn::EltwiseOpParams::NONE;
::flatbuffers::Offset<void> params = 0;
if constexpr (std::is_same_v<EltwiseOp, AbsOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Abs;
} else if constexpr (std::is_same_v<EltwiseOp, AddOp>) {
@@ -360,7 +363,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
return ::tt::target::ttnn::CreateEltwiseOpDirect(
*cache.fbb, type, &ins,
cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getOutputs().front())));
getOperandThroughDPSOps(op.getOutputs().front())),
paramsType, params);
}

template <typename ReductionOp>
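With the schema change above, createEltwiseOp now forwards paramsType and params explicitly, even though every current op leaves them at NONE and 0. The following is a hedged sketch of how a future parameterized op might extend the if-constexpr chain; ClampOp, Clamp, ClampOpParams, CreateClampOpParams, getMin(), and getMax() are hypothetical names, not part of this commit.

// Sketch only: a hypothetical continuation of the std::is_same_v chain above.
} else if constexpr (std::is_same_v<EltwiseOp, ClampOp>) {
  type = ::tt::target::ttnn::EltwiseOpType::Clamp;
  paramsType = ::tt::target::ttnn::EltwiseOpParams::ClampOpParams;
  params = ::tt::target::ttnn::CreateClampOpParams(*cache.fbb, op.getMin(), op.getMax())
               .Union();
}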
16 changes: 12 additions & 4 deletions runtime/include/tt/runtime/detail/workarounds.h
@@ -17,12 +17,12 @@ struct Env {
#endif
get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true,
bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true,
bool setMatmul1DProgramConfig = true)
bool setMatmul1DProgramConfig = true, bool swapBinaryOperands = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true, true, true);
return Env(true, true, true, true, true, true);
}
#endif
// TODO(bug #272), determine correct layout by tile shape in the future
@@ -45,15 +45,21 @@ struct Env {
// TODO(bug #891): ttnn::matmul doesn't choose correct program config.
bool setMatmul1DProgramConfig;

// TODO(bug #1124): We're currently swapping the operands for binary ops
// in runtime if the lhs operand is smaller (and requires broadcast onto the
// rhs operand). We should add this check in the compiler.
bool swapBinaryOperands;

private:
constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool setMatmul1DProgramConfig)
bool setMatmul1DProgramConfig, bool swapBinaryOperands)
: ignoreTileShape(ignoreTileShape),
emptyOpForceRowMajor(emptyOpForceRowMajor),
fullOpForceRowMajor(fullOpForceRowMajor),
maxpool2dPreshard(maxpool2dPreshard),
setMatmul1DProgramConfig(setMatmul1DProgramConfig) {}
setMatmul1DProgramConfig(setMatmul1DProgramConfig),
swapBinaryOperands(swapBinaryOperands) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
@@ -68,6 +74,8 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) {
<< "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
os << "\t"
<< "setMatmul1DProgramConfig: " << env.setMatmul1DProgramConfig << "\n";
os << "\t"
<< "swapBinaryOperands: " << env.swapBinaryOperands << "\n";
os << "}";
return os;
}
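When TT_RUNTIME_WORKAROUNDS is enabled, Env::get stores its arguments in a function-local static, so the first call pins the workaround set for the rest of the process. A minimal usage sketch that keeps the existing workarounds on but opts out of the new operand swap (the helper function and flag choices are purely illustrative):

#include "tt/runtime/detail/workarounds.h"

// Illustrative helper, not part of the runtime API.
void configureWorkarounds() {
  // First call wins: later Env::get() calls return this same cached config.
  const auto &env = tt::runtime::workaround::Env::get(
      /*ignoreTileShape=*/true, /*emptyOpForceRowMajor=*/true,
      /*fullOpForceRowMajor=*/true, /*maxpool2dPreshard=*/true,
      /*setMatmul1DProgramConfig=*/true, /*swapBinaryOperands=*/false);
  (void)env.swapBinaryOperands; // false: binary eltwise operands keep their order
}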
4 changes: 2 additions & 2 deletions runtime/lib/common/workarounds.cpp
@@ -8,10 +8,10 @@ namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool setMatmul1DProgramConfig) {
bool setMatmul1DProgramConfig, bool swapBinaryOperands) {
static const Env config(ignoreTileShape, emptyOpForceRowMajor,
fullOpForceRowMajor, maxpool2dPreshard,
setMatmul1DProgramConfig);
setMatmul1DProgramConfig, swapBinaryOperands);
return config;
}
#endif
8 changes: 6 additions & 2 deletions runtime/lib/ttnn/operations/CMakeLists.txt
@@ -1,5 +1,7 @@
set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
@@ -9,8 +11,10 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/slice.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp
${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp
runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
@@ -4,26 +4,12 @@
#include "binary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"

namespace tt::runtime::ttnn::operations::binary {

static void
getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) {
LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs");
*lhs = &(tensorPool.at(op->ins()->Get(0)->global_id()));
*rhs = &(tensorPool.at(op->ins()->Get(1)->global_id()));
DEBUG_ASSERT((*lhs)->is_allocated());
DEBUG_ASSERT((*rhs)->is_allocated());

// Switch the order of operands if the second operand requires broadcast
if ((*rhs)->volume() < (*lhs)->volume()) {
std::swap(*lhs, *rhs);
}
}

static void runEltwiseBinaryOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(
@@ -48,24 +34,6 @@ static void runEltwiseBinaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseBinaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
const std::optional<::tt::tt_metal::MemoryConfig> &)>
ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
@@ -118,14 +86,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryOP(op, tensorPool, ::ttnn::divide);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Maximum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
default:
throw std::invalid_argument("Unsupported Eltwise Binary operation");
}
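The operand-resolution helper deleted from binary.cpp above is not gone: per the CMakeLists change it moves into the shared eltwise/binary/utils.cpp so that binary.cpp and the new binary_composite.cpp use one copy. A sketch of the declaration the new "tt/runtime/ttnn/operations/eltwise/binary/utils.h" include presumably provides follows; the exact header contents are assumed, not shown in this diff.

// Assumed contents of tt/runtime/ttnn/operations/eltwise/binary/utils.h.
#ifndef TT_RUNTIME_TTNN_OPERATIONS_ELTWISE_BINARY_UTILS_H
#define TT_RUNTIME_TTNN_OPERATIONS_ELTWISE_BINARY_UTILS_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::binary {

// Resolves both input tensors of a binary eltwise op; presumably also applies the
// swapBinaryOperands workaround (bug #1124), swapping lhs/rhs when the lhs is the
// smaller, broadcast-side operand.
void getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
                                    ProgramTensorPool &tensorPool,
                                    ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs);

} // namespace tt::runtime::ttnn::operations::binary

#endif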
47 changes: 47 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "binary_composite.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::binary::composite {

static void runEltwiseBinaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
const std::optional<::tt::tt_metal::MemoryConfig> &)>
ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Maximum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
default:
throw std::invalid_argument(
"Unsupported Eltwise Binary Composite operation");
}
}

} // namespace tt::runtime::ttnn::operations::binary::composite
27 changes: 27 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H
#define TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::binary::composite {

inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Maximum:
case ::tt::target::ttnn::EltwiseOpType::Minimum:
return true;
default:
return false;
}
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::binary::composite

#endif
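isBinaryCompositeOp lets the program executor route an EltwiseOp to either the plain binary handler or the new composite one without growing either switch statement. An illustrative dispatch sketch follows; the real executor wiring is outside this diff, and unary ops would route analogously through isUnaryCompositeOp.

// Illustrative only: not the actual executor code from this commit.
#include "tt/runtime/ttnn/operations/eltwise/binary/binary.h"           // assumed path
#include "tt/runtime/ttnn/operations/eltwise/binary/binary_composite.h" // assumed path

namespace tt::runtime::ttnn::operations {

void dispatchBinaryEltwise(const ::tt::target::ttnn::EltwiseOp *op,
                           ProgramContext &context) {
  if (binary::composite::isBinaryCompositeOp(op)) {
    binary::composite::run(op, context); // Maximum, Minimum
  } else {
    binary::run(op, context); // Add, Multiply, Divide, ...
  }
}

} // namespace tt::runtime::ttnn::operations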
runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
@@ -4,22 +4,12 @@
#include "unary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/copy.hpp"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::unary {

static void
getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **in) {
LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ",
op->ins()->size());
*in = &(tensorPool.at(op->ins()->Get(0)->global_id()));
DEBUG_ASSERT((*in)->is_allocated());
}

static void runEltwiseUnaryOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
@@ -38,22 +28,6 @@ static void runEltwiseUnaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {

::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryWithFastAndApproximateModeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
@@ -80,10 +54,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Ceil: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::ceil);
break;
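Mirroring the binary path, the single-input helper deleted from unary.cpp above moves into the shared eltwise/unary/utils.cpp listed in the CMakeLists change, so unary.cpp and the new unary_composite.cpp resolve their input the same way. A short sketch of the assumed declaration (include guards and includes as in the binary utils sketch above):

// Assumed contents of tt/runtime/ttnn/operations/eltwise/unary/utils.h.
namespace tt::runtime::ttnn::operations::unary {

// Asserts exactly one input and fetches it from the tensor pool.
void getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op,
                                  ProgramTensorPool &tensorPool,
                                  ::ttnn::Tensor **in);

} // namespace tt::runtime::ttnn::operations::unary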
File renamed without changes.
42 changes: 42 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "unary_composite.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::unary::composite {

static void runEltwiseUnaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {

::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
default:
throw std::invalid_argument(
"Unsupported Eltwise Binary Composite operation");
}
}

} // namespace tt::runtime::ttnn::operations::unary::composite
26 changes: 26 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H
#define TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::unary::composite {

inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt:
return true;
default:
return false;
}
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::unary::composite

#endif