Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update eltwise flatbuffer to include extra params, move eltwise composite to own files #1121

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,15 @@ enum EltwiseOpType: uint32 {
Cos = 27
}

union EltwiseOpParams {

}

table EltwiseOp {
type: EltwiseOpType;
ins: [tt.target.TensorRef];
out: tt.target.TensorRef;
params: EltwiseOpParams;
}

enum ReductionOpType: uint32 {
Expand Down
6 changes: 5 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
::tt::target::ttnn::EltwiseOpType type;
::tt::target::ttnn::EltwiseOpParams paramsType =
::tt::target::ttnn::EltwiseOpParams::NONE;
::flatbuffers::Offset<void> params = 0;
if constexpr (std::is_same_v<EltwiseOp, AbsOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Abs;
} else if constexpr (std::is_same_v<EltwiseOp, AddOp>) {
Expand Down Expand Up @@ -360,7 +363,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
return ::tt::target::ttnn::CreateEltwiseOpDirect(
*cache.fbb, type, &ins,
cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getOutputs().front())));
getOperandThroughDPSOps(op.getOutputs().front())),
paramsType, params);
}

template <typename ReductionOp>
Expand Down
16 changes: 12 additions & 4 deletions runtime/include/tt/runtime/detail/workarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ struct Env {
#endif
get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true,
bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true,
bool setMatmul1DProgramConfig = true)
bool setMatmul1DProgramConfig = true, bool swapBinaryOperands = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true, true, true);
return Env(true, true, true, true, true, true);
}
#endif
// TODO(bug #272), determine correct layout by tile shape in the future
Expand All @@ -45,15 +45,21 @@ struct Env {
// TODO(bug #891): ttnn::matmul doesn't chose correct program config.
bool setMatmul1DProgramConfig;

// TODO(bug #1124): We're currently swapping the operands for binary ops
// in runtime if the lhs operand is smaller (and requires broadcast onto the
// rhs operand). We should add this check in the compiler.
bool swapBinaryOperands;

private:
constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool setMatmul1DProgramConfig)
bool setMatmul1DProgramConfig, bool swapBinaryOperands)
: ignoreTileShape(ignoreTileShape),
emptyOpForceRowMajor(emptyOpForceRowMajor),
fullOpForceRowMajor(fullOpForceRowMajor),
maxpool2dPreshard(maxpool2dPreshard),
setMatmul1DProgramConfig(setMatmul1DProgramConfig) {}
setMatmul1DProgramConfig(setMatmul1DProgramConfig),
swapBinaryOperands(swapBinaryOperands) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
Expand All @@ -68,6 +74,8 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) {
<< "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n";
os << "\t"
<< "setMatmul1DProgramConfig: " << env.setMatmul1DProgramConfig << "\n";
os << "\t"
<< "swapBinaryOperands: " << env.swapBinaryOperands << "\n";
os << "}";
return os;
}
Expand Down
4 changes: 2 additions & 2 deletions runtime/lib/common/workarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor,
bool fullOpForceRowMajor, bool maxpool2dPreshard,
bool setMatmul1DProgramConfig) {
bool setMatmul1DProgramConfig, bool swapBinaryOperands) {
static const Env config(ignoreTileShape, emptyOpForceRowMajor,
fullOpForceRowMajor, maxpool2dPreshard,
setMatmul1DProgramConfig);
setMatmul1DProgramConfig, swapBinaryOperands);
return config;
}
#endif
Expand Down
8 changes: 6 additions & 2 deletions runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
Expand All @@ -9,8 +11,10 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/slice.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data_movement/transpose.cpp
${CMAKE_CURRENT_SOURCE_DIR}/deletion/dealloc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,12 @@
#include "binary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"

namespace tt::runtime::ttnn::operations::binary {

static void
getEltwiseBinaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) {
LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs");
*lhs = &(tensorPool.at(op->ins()->Get(0)->global_id()));
*rhs = &(tensorPool.at(op->ins()->Get(1)->global_id()));
DEBUG_ASSERT((*lhs)->is_allocated());
DEBUG_ASSERT((*rhs)->is_allocated());

// Switch the order of operands if the second operand requires broadcast
if ((*rhs)->volume() < (*lhs)->volume()) {
std::swap(*lhs, *rhs);
}
}

static void runEltwiseBinaryOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(
Expand All @@ -48,24 +34,6 @@ static void runEltwiseBinaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseBinaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
const std::optional<::tt::tt_metal::MemoryConfig> &)>
ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
Expand Down Expand Up @@ -118,14 +86,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryOP(op, tensorPool, ::ttnn::divide);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Maximum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
default:
throw std::invalid_argument("Unsupported Eltwise Binary operation");
}
Expand Down
47 changes: 47 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "binary_composite.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::binary::composite {

// Shared driver for composite binary eltwise ops (maximum, minimum, ...).
// Resolves both input tensors from the pool, derives the output memory
// config from the flatbuffer tensor ref, invokes the supplied ttnn op, and
// publishes the result under the output's global id.
static void runEltwiseBinaryCompositeOP(
    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
    std::function<
        ::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
                       const std::optional<::tt::tt_metal::MemoryConfig> &)>
        ttnnOp) {
  ::ttnn::Tensor *inputA = nullptr;
  ::ttnn::Tensor *inputB = nullptr;
  // NOTE(review): the shared helper presumably may reorder the operands for
  // the broadcast workaround (bug #1124) — confirm against binary/utils.cpp.
  getEltwiseBinaryOPInputTensors(op, tensorPool, &inputA, &inputB);

  const ::tt::tt_metal::MemoryConfig memConfig =
      utils::createMemoryConfig(op->out());

  // MemoryConfig converts implicitly to the optional parameter of ttnnOp.
  ::ttnn::Tensor result = ttnnOp(*inputA, *inputB, memConfig);
  tensorPool.insert_or_assign(op->out()->global_id(), result);
}

// Dispatch a flatbuffer EltwiseOp to the matching composite binary ttnn op.
// Throws std::invalid_argument for op types not handled here.
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  const ::tt::target::ttnn::EltwiseOpType opType = op->type();
  if (opType == ::tt::target::ttnn::EltwiseOpType::Maximum) {
    runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
  } else if (opType == ::tt::target::ttnn::EltwiseOpType::Minimum) {
    runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
  } else {
    throw std::invalid_argument(
        "Unsupported Eltwise Binary Composite operation");
  }
}

} // namespace tt::runtime::ttnn::operations::binary::composite
27 changes: 27 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H
#define TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::binary::composite {

// True iff this eltwise op is executed via the composite binary path
// (ops whose ttnn signature takes an optional MemoryConfig rather than the
// plain binary op signature).
inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
  const ::tt::target::ttnn::EltwiseOpType opType = op->type();
  return opType == ::tt::target::ttnn::EltwiseOpType::Maximum ||
         opType == ::tt::target::ttnn::EltwiseOpType::Minimum;
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::binary::composite

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,12 @@
#include "unary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/copy.hpp"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::unary {

static void
getEltwiseUnaryOPInputTensor(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **in) {
LOG_ASSERT(op->ins()->size() == 1, "Expected 1 input, got ",
op->ins()->size());
*in = &(tensorPool.at(op->ins()->Get(0)->global_id()));
DEBUG_ASSERT((*in)->is_allocated());
}

static void runEltwiseUnaryOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
Expand All @@ -38,22 +28,6 @@ static void runEltwiseUnaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryCompositeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {

::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryWithFastAndApproximateModeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
Expand All @@ -80,10 +54,6 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Ceil: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::ceil);
break;
Expand Down
42 changes: 42 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "unary_composite.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"

namespace tt::runtime::ttnn::operations::unary::composite {

static void runEltwiseUnaryCompositeOP(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we mean when we say "Composite" op in our runtime?

Copy link
Contributor Author

@jnie-TT jnie-TT Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It reflects the scheme of ttnn (in src/tt-metal/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp). They have a concept of a unaryCompositeOp that has a different execution-function signature than ordinary unary ops (the same applies to binary_composite). A lot of these ops take extra parameters besides the input tensor, such as min/max for clamp.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see ok makes sense.

const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
const ::tt::tt_metal::MemoryConfig &)>
ttnnOp) {

::ttnn::Tensor *in = nullptr;
getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

// Dispatch a flatbuffer EltwiseOp to the matching composite unary ttnn op.
// Throws std::invalid_argument for op types not handled here.
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  switch (op->type()) {
  case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
    runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
    break;
  }
  default:
    // Fixed copy-paste bug: this is the *unary* composite dispatcher, but the
    // message previously said "Binary", which would misdirect debugging.
    throw std::invalid_argument(
        "Unsupported Eltwise Unary Composite operation");
  }
}

} // namespace tt::runtime::ttnn::operations::unary::composite
26 changes: 26 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H
#define TTNN_RUNTIME_ELTWISE_UNARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::unary::composite {

// True iff this eltwise op is executed via the composite unary path
// (currently only Cbrt, which ttnn exposes with the composite signature).
inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
  return op->type() == ::tt::target::ttnn::EltwiseOpType::Cbrt;
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::unary::composite

#endif
Loading
Loading