-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reshape only in row major layout (#687)
* Convert to `row_major` layout before calling `reshape` TTNN supports two kinds of reshapes - row and tile. row reshape works as expected as long as we don't change last dim. On the other hand tile reshape requires that output `W` and `H` dims are tile aligned. For example this `reshape` in tile layout will fail: 1*12*32 -> 12*32 Error that we get is `Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.` In addition when we managed to do a `reshape` on tile tensor we found that output tensor did not look the way we expected it to look. For example: If we have input row major tensor 4x4 which contains numbers from 0 .. 15 and we reshape it to 2x8 we will get [[0,1 .. 6,7], [8,9 .. 14,15]]. The output is correct. If we take the same input tensor and convert it to tile layout and call reshape to Shape((2, 8), (32, 32)) and convert output tensor to row_major we will get tensor which looks like this: [[0,1,2,3,0,0,0,0],[4,5,6,7,0,0,0,0]].This is different then row major reshape. In order to unblock forge for `reshape` op we will convert all input tensors to row before calling reshape, and then we will convert output tensor back to original layout. Once we have understanding how reshape works for tile we will remove this and implement proper fix. * Adding comment * Moving test to Silicon. Adding negative test for dialect.
- Loading branch information
1 parent
056534a
commit 8152560
Showing
6 changed files
with
52 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 10 additions & 0 deletions
10
test/ttmlir/Dialect/TTNN/reshape/reshape_fail_on_dims.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s | ||
// CHECK: error: 'ttir.reshape' op Shape attribute must match the output tensor shape for dimensions that are not -1 | ||
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile> | ||
module attributes {} { | ||
func.func @forward(%arg0: tensor<4x2x32x34xbf16>) -> tensor<2x4x32x34xbf16> { | ||
%0 = tensor.empty() : tensor<2x4x32x34xbf16> | ||
%1 = "ttir.reshape"(%arg0, %0) <{shape = [3: i32, 4: i32, 32: i32, 34: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x34xbf16>, tensor<2x4x32x34xbf16>) -> tensor<2x4x32x34xbf16> | ||
return %1 : tensor<2x4x32x34xbf16> | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s | ||
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile> | ||
module attributes {} { | ||
func.func @forward(%arg0: tensor<4x2x32x34xbf16>) -> tensor<2x4x32x34xbf16> { | ||
%0 = tensor.empty() : tensor<2x4x32x34xbf16> | ||
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] | ||
%1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 34: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x34xbf16>, tensor<2x4x32x34xbf16>) -> tensor<2x4x32x34xbf16> | ||
return %1 : tensor<2x4x32x34xbf16> | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
12 changes: 12 additions & 0 deletions
12
test/ttmlir/Silicon/TTNN/reshape/reshape_tile_aligned.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir | ||
// RUN: FileCheck %s --input-file=%t.mlir | ||
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn | ||
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile> | ||
module attributes {} { | ||
func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<4x32x2x32xbf16> { | ||
%0 = tensor.empty() : tensor<4x32x2x32xbf16> | ||
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] | ||
%1 = "ttir.reshape"(%arg0, %0) <{shape = [4: i32, 32: i32, 2: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<4x32x2x32xbf16>) -> tensor<4x32x2x32xbf16> | ||
return %1 : tensor<4x32x2x32xbf16> | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir | ||
// RUN: FileCheck %s --input-file=%t.mlir | ||
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn | ||
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile> | ||
module attributes {} { | ||
func.func @forward(%arg0: tensor<4x2x32x34xbf16>) -> tensor<2x4x32x34xbf16> { | ||
%0 = tensor.empty() : tensor<2x4x32x34xbf16> | ||
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] | ||
%1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 34: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x34xbf16>, tensor<2x4x32x34xbf16>) -> tensor<2x4x32x34xbf16> | ||
return %1 : tensor<2x4x32x34xbf16> | ||
} | ||
} |