diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 0bd72d189..9a62b2280 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -173,9 +173,14 @@ class ConcatOpConversionPattern : public OpConversionPattern { LogicalResult matchAndRewrite(ttir::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + int dim = adaptor.getDim(); + if (dim < 0) { + dim += cast(adaptor.getInputs().front().getType()) + .getRank(); + } rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInputs(), adaptor.getOutput(), adaptor.getDim()); + adaptor.getInputs(), adaptor.getOutput(), dim); return success(); } }; diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 87827ffe3..ae2cff496 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -167,9 +167,13 @@ ::mlir::LogicalResult mlir::tt::ttir::ConcatOp::verify() { mlir::cast(inputs.front().getType()); int64_t firstTensorRank = firstTensor.getRank(); + if (dim < 0) { + dim += firstTensorRank; + } + // Check that the dimension `dim` is valid. if (dim < 0 || dim >= firstTensor.getRank()) { - return emitOpError() << "Invalid dimension " << dim + return emitOpError() << "Invalid dimension " << getDim() << " for concatenation."; } diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 02e71fb6c..7c80fcd12 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -159,9 +159,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() { mlir::cast(inputs.front().getType()); int64_t firstTensorRank = firstTensor.getRank(); + if (dim < 0) { + dim += firstTensorRank; + } + // Check that the dimension `dim` is valid. if (dim < 0 || dim >= firstTensor.getRank()) { - return emitOpError() << "Invalid dimension " << dim + return emitOpError() << "Invalid dimension " << getDim() << " for concatenation."; } diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir new file mode 100644 index 000000000..5b93d0c50 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir @@ -0,0 +1,10 @@ +// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s +// CHECK: error: 'ttir.concat' op Invalid dimension 2 for concatenation. +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { + %0 = tensor.empty() : tensor<32x96xf32> + %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 2 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + return %1 : tensor<32x96xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir new file mode 100644 index 000000000..30bf6926b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir @@ -0,0 +1,17 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward() -> tensor<32x224xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %1 = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %2 = tensor.empty() : tensor<32x128xf32> + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %3 = tensor.empty() : tensor<32x224xf32> + // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]] + %4 = "ttir.concat"(%0, %1, %2, %3) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x128xf32>, tensor<32x224xf32>) -> tensor<32x224xf32> + return %4 : tensor<32x224xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir new file mode 100644 index 000000000..f8a4f2db3 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<32x96xf32> + // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]] + %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + return %1 : tensor<32x96xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir new file mode 100644 index 000000000..5d3a6fbd6 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir @@ -0,0 +1,10 @@ +// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s +// CHECK: error: 'ttir.concat' op Invalid dimension -3 for concatenation. +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { + %0 = tensor.empty() : tensor<32x96xf32> + %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -3 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + return %1 : tensor<32x96xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_concat.mlir b/test/ttmlir/Dialect/TTNN/concat/simple_concat.mlir similarity index 100% rename from test/ttmlir/Dialect/TTNN/simple_concat.mlir rename to test/ttmlir/Dialect/TTNN/concat/simple_concat.mlir