From a94fffeff2b45aa21964b82190a24d9ce27775b3 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Wed, 30 Oct 2024 16:49:26 +0000 Subject: [PATCH] Added support for stablehlo.remainder op. Added tests. Tested with ttrt. --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 15 +++++++++++++++ include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 15 +++++++++++++++ include/ttmlir/Target/TTNN/program.fbs | 3 ++- .../StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 2 ++ lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 1 + lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 4 +++- lib/Target/TTNN/TTNNToFlatbuffer.cpp | 6 ++++++ .../eltwise/binary/binary_composite.cpp | 4 ++++ .../operations/eltwise/binary/binary_composite.h | 1 + .../StableHLOToTTIR/binary/remainder_op.mlir | 12 ++++++++++++ .../binary/remainder/simple_remainder.mlir | 12 ++++++++++++ .../TTNN/perf_unit/test_perf_remainder.mlir | 14 ++++++++++++++ test/ttmlir/Silicon/TTNN/simple_eltwise.mlir | 9 +++++++++ 13 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/binary/remainder_op.mlir create mode 100644 test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir create mode 100644 test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index c87770bef..3e0d27598 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -420,6 +420,21 @@ def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> { }]; } +def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> { + let summary = "Eltwise remainder."; + let description = [{ + Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a + result tensor. + + Example: + + // %lhs: [17, -17, 17, -17] + // %rhs: [3, 3, -3, -3] + %result = "ttir.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64> + // %result: [2, -2, 2, -2] + }]; +} + class TTIR_ReductionOp traits = []> : TTIR_DPSOp { diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 01ebae803..19132cdd5 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -388,6 +388,21 @@ def TTNN_SubtractOp : TTNN_ElementwiseBinaryOp<"subtract"> { }]; } +def TTNN_RemainderOp : TTNN_ElementwiseBinaryOp<"remainder"> { + let summary = "Eltwise remainder."; + let description = [{ + Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a + result tensor. + + Example: + + // %lhs: [17, -17, 17, -17] + // %rhs: [3, 3, -3, -3] + %result = "ttnn.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64> + // %result: [2, -2, 2, -2] + }]; +} + class TTNN_ReductionOp traits = []> : TTNN_Op { let summary = "Reduction op."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index cea35447c..598d97b79 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -89,7 +89,8 @@ enum EltwiseOpType: uint32 { Log = 28, Log1p = 29, Expm1 = 30, - Sign = 31 + Sign = 31, + Remainder = 32 } union EltwiseOpParams { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 15b1f086b..d94304827 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -884,6 +884,8 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); } void addReduceOpsConversionPatterns(MLIRContext *ctx, diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 42d834634..02159e1b0 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -887,6 +887,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 0582dce37..c29d00757 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -647,7 +647,9 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, - DefaultOpConversionPattern>(typeConverter, ctx); + DefaultOpConversionPattern, + DefaultOpConversionPattern>(typeConverter, + ctx); // Tensor manipulation ops // diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 71d793a00..d999570d2 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -359,6 +359,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Log; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Expm1; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Remainder; } else { llvm_unreachable("unhandled EltwiseOp"); } @@ -628,6 +630,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, if (auto divOp = dyn_cast(op); divOp) { return createOperation(cache, createEltwiseOp(cache, divOp), debugString); } + if (auto remainderOp = dyn_cast(op); remainderOp) { + return createOperation(cache, createEltwiseOp(cache, remainderOp), + debugString); + } if (auto matmulOp = dyn_cast(op); matmulOp) { return createOperation(cache, createOp(cache, matmulOp), debugString); } diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp index e0dbddf8f..09b154e6e 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -38,6 +38,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum); break; } + case ::tt::target::ttnn::EltwiseOpType::Remainder: { + runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::remainder); + break; + } default: throw std::invalid_argument( "Unsupported Eltwise Binary Composite operation"); diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h index e04059940..47d0e25b2 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h @@ -14,6 +14,7 @@ inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { switch (op->type()) { case ::tt::target::ttnn::EltwiseOpType::Maximum: case ::tt::target::ttnn::EltwiseOpType::Minimum: + case ::tt::target::ttnn::EltwiseOpType::Remainder: return true; default: return false; diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/binary/remainder_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/binary/remainder_op.mlir new file mode 100644 index 000000000..bbca3a3f9 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/binary/remainder_op.mlir @@ -0,0 +1,12 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_eltwise_remainder attributes {} { + func.func public @test_remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = stablehlo.remainder %arg0, %arg1 : tensor<32x32xf32> + // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[REM:[0-9]+]] = "ttir.remainder"(%arg0, %arg1, %[[EMPTY]]){{.*}} -> tensor<32x32xf32> + return %0 : tensor<32x32xf32> + // CHECK: return %[[REM]] : tensor<32x32xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir new file mode 100644 index 000000000..281dccfdd --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} + %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} + return %1 : tensor<32x32xf32> + // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir new file mode 100644 index 000000000..68375a9e0 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} + %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} + return %1 : tensor<32x32xf32> + // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} +} diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index 4b3cd93bf..7e6bcbcda 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -218,3 +218,12 @@ func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } + +func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} + %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} + return %1 : tensor<32x32xf32> + // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} +}