Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement conversion for stablehlo.select and add Where Op #852

Merged
merged 12 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,29 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
let results = (outs Variadic<AnyRankedTensor>:$results);
}

class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise ternary op.";
let description = [{
Eltwise ternary op.
}];

let builders =
[
OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints);
}]>
];
}

def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
let summary = "Eltwise where op.";
let description = [{
Eltwise where operation.
}];
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise unary op.";
Expand Down
23 changes: 23 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,29 @@ class TTNN_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

class TTNN_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTNN_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise ternary op.";
let description = [{
Eltwise ternary op.
}];

let builders =
[
OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out),
[{
build($_builder, $_state, {out.getType()}, {first, second, third}, out);
}]>
];
}

def TTNN_WhereOp : TTNN_ElementwiseTernaryOp<"where"> {
let summary = "Eltwise where.";
let description = [{
Eltwise where operation.
}];
}

def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
let summary = "Eltwise absolute.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ enum EltwiseOpType: uint32 {
Remainder = 32,
IsFinite = 33,
Floor = 34,
Where = 35,
}

union EltwiseOpParams {
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,8 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::RemOp, mlir::tt::ttir::RemainderOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SelectOp, mlir::tt::ttir::WhereOp>>(typeConverter, ctx);
}

void addReduceOpsConversionPatterns(MLIRContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
ElementwiseOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Other ops
//
patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>,
DefaultOpConversionPattern<ttnn::EmbeddingOp>>(typeConverter,
ctx);
DefaultOpConversionPattern<ttnn::EmbeddingOp>,
DefaultOpConversionPattern<ttnn::WhereOp>>(typeConverter, ctx);

// CCL ops
//
Expand Down
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Expm1;
} else if constexpr (std::is_same_v<EltwiseOp, RemainderOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Remainder;
} else if constexpr (std::is_same_v<EltwiseOp, WhereOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Where;
} else {
llvm_unreachable("unhandled EltwiseOp");
}
Expand Down Expand Up @@ -719,6 +721,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto sinOp = dyn_cast<SinOp>(op); sinOp) {
return createOperation(cache, createEltwiseOp(cache, sinOp), debugString);
}
if (auto whereOp = dyn_cast<WhereOp>(op); whereOp) {
return createOperation(cache, createEltwiseOp(cache, whereOp), debugString);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
}
Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "ttnn/operations/data_movement/permute/permute.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
#include "ttnn/operations/eltwise/ternary/where.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/embedding/embedding.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
Expand Down
2 changes: 2 additions & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
Expand All @@ -15,6 +16,7 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/ternary/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp
Expand Down
32 changes: 32 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ternary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::ternary {

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
if (op->type() != ::tt::target::ttnn::EltwiseOpType::Where) {
throw std::invalid_argument("Unsupported Eltwise Ternary operation");
}

ProgramTensorPool &tensorPool = context.getTensorPool();

::ttnn::Tensor *first = nullptr;
::ttnn::Tensor *second = nullptr;
::ttnn::Tensor *third = nullptr;
getEltwiseTernaryOPInputTensors(op, tensorPool, &first, &second, &third);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out =
::ttnn::where(*first, *second, *third, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::ternary
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/ternary/ternary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H
#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ternary {

inline bool isTernaryOp(const ::tt::target::ttnn::EltwiseOp *op) {
return op->ins()->size() == 3;
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::ternary

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "utils.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/workarounds.h"

namespace tt::runtime::ttnn::operations::ternary {

void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **first,
::ttnn::Tensor **second,
::ttnn::Tensor **third) {
LOG_ASSERT(op->ins()->size() == 3, "Expected 3 inputs");
*first = &(tensorPool.at(op->ins()->Get(0)->global_id()));
*second = &(tensorPool.at(op->ins()->Get(1)->global_id()));
*third = &(tensorPool.at(op->ins()->Get(2)->global_id()));
DEBUG_ASSERT((*first)->is_allocated());
DEBUG_ASSERT((*second)->is_allocated());
DEBUG_ASSERT((*third)->is_allocated());
}

} // namespace tt::runtime::ttnn::operations::ternary
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H
#define TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H

#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ternary {
void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **first,
::ttnn::Tensor **second,
::ttnn::Tensor **third);

} // namespace tt::runtime::ttnn::operations::ternary

#endif
6 changes: 6 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "operations/deletion/dealloc.h"
#include "operations/eltwise/binary/binary.h"
#include "operations/eltwise/binary/binary_composite.h"
#include "operations/eltwise/ternary/ternary.h"
#include "operations/eltwise/unary/unary.h"
#include "operations/eltwise/unary/unary_composite.h"
#include "operations/embedding/embedding.h"
Expand Down Expand Up @@ -73,13 +74,18 @@ void ProgramExecutor::runEltwiseOperation(
return operations::binary::run(op, context);
};

auto runTernaryOp = [&]() { return operations::ternary::run(op, context); };

if (operations::unary::isUnaryOp(op)) {
return runUnaryOp();
}

if (operations::binary::isBinaryOp(op)) {
return runBinaryOp();
}
if (operations::ternary::isTernaryOp(op)) {
return runTernaryOp();
}

throw std::invalid_argument("Unsupported Eltwise operation");
}
Expand Down
13 changes: 13 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_select attributes {} {
func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
%0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1>
%1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty()
// CHECK: %[[VAL1:[0-9]+]] = "ttir.eq"
// CHECK: %[[SELECT:[0-9]+]] = "ttir.where"(%[[VAL1:[0-9]+]], %arg0, %arg1, %[[EMPTY:[0-9]+]]) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
return %1 : tensor<13x37xf32>
}
}
14 changes: 14 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_where.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module @jit_eltwise_where {
func.func public @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
%0 = tensor.empty() : tensor<13x37xf32>
%1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
uazizTT marked this conversation as resolved.
Show resolved Hide resolved
%2 = tensor.empty() : tensor<13x37xf32>
%3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}}
// CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]])
// CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
return %3 : tensor<13x37xf32>
}
}
16 changes: 16 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
%0 = tensor.empty() : tensor<13x37xbf16>
%1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16>
%2 = tensor.empty() : tensor<13x37xf32>
%3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}}
// CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]])
// CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
return %3 : tensor<13x37xf32>
}
11 changes: 11 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,14 @@ func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> {
return %0 : tensor<1xi32>
// CHECK: return [[VAL]] : tensor<1xi32, {{.*}}>
}

func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
%0 = tensor.empty() : tensor<13x37xbf16>
%1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16>
%2 = tensor.empty() : tensor<13x37xf32>
%3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}}
// CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]])
// CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
return %3 : tensor<13x37xf32>
}
Loading