Skip to content

Commit

Permalink
[Minor] Negative dim support for concat (#651)
Browse files Browse the repository at this point in the history
* Adding negative dim support for `concat`

Torch supports it. Forge passes negative dim attr for concat.

* Renaming dim to axis to align with tt-forge attr name

* Revert dim attr rename

* Remove test
  • Loading branch information
mtopalovicTT authored Sep 10, 2024
1 parent 6bec67a commit 3f52ccc
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 3 deletions.
7 changes: 6 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,14 @@ class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
LogicalResult
matchAndRewrite(ttir::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int dim = adaptor.getDim();
if (dim < 0) {
dim += cast<RankedTensorType>(adaptor.getInputs().front().getType())
.getRank();
}
rewriter.replaceOpWithNewOp<ttnn::ConcatOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInputs(), adaptor.getOutput(), adaptor.getDim());
adaptor.getInputs(), adaptor.getOutput(), dim);
return success();
}
};
Expand Down
6 changes: 5 additions & 1 deletion lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,13 @@ ::mlir::LogicalResult mlir::tt::ttir::ConcatOp::verify() {
mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
int64_t firstTensorRank = firstTensor.getRank();

if (dim < 0) {
dim += firstTensorRank;
}

// Check that the dimension `dim` is valid.
if (dim < 0 || dim >= firstTensor.getRank()) {
return emitOpError() << "Invalid dimension " << dim
return emitOpError() << "Invalid dimension " << getDim()
<< " for concatenation.";
}

Expand Down
6 changes: 5 additions & 1 deletion lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
int64_t firstTensorRank = firstTensor.getRank();

if (dim < 0) {
dim += firstTensorRank;
}

// Check that the dimension `dim` is valid.
if (dim < 0 || dim >= firstTensor.getRank()) {
return emitOpError() << "Invalid dimension " << dim
return emitOpError() << "Invalid dimension " << getDim()
<< " for concatenation.";
}

Expand Down
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
// CHECK: error: 'ttir.concat' op Invalid dimension 2 for concatenation.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
%0 = tensor.empty() : tensor<32x96xf32>
%1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 2 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
return %1 : tensor<32x96xf32>
}
}
17 changes: 17 additions & 0 deletions test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward() -> tensor<32x224xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%1 = tensor.empty() : tensor<32x64xf32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%2 = tensor.empty() : tensor<32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%3 = tensor.empty() : tensor<32x224xf32>
// CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]]
%4 = "ttir.concat"(%0, %1, %2, %3) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x128xf32>, tensor<32x224xf32>) -> tensor<32x224xf32>
return %4 : tensor<32x224xf32>
}
}
11 changes: 11 additions & 0 deletions test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x96xf32>
// CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]]
%1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
return %1 : tensor<32x96xf32>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
// CHECK: error: 'ttir.concat' op Invalid dimension -3 for concatenation.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
%0 = tensor.empty() : tensor<32x96xf32>
%1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -3 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
return %1 : tensor<32x96xf32>
}
}
File renamed without changes.

0 comments on commit 3f52ccc

Please sign in to comment.