Reshape only in row major layout (#687)
* Convert to `row_major` layout before calling `reshape`

TTNN supports two kinds of reshapes: row-major and tile. A row-major reshape works
as expected as long as we don't change the last dim.
A tile reshape, on the other hand, requires that the output `W` and `H` dims are tile-aligned.

For example, this `reshape` in tile layout will fail:
1x12x32 -> 12x32

The error we get is `Unable to reshape a tensor in TILE_LAYOUT
to non-tile height and width! Please convert the tensor to
ROW_MAJOR_LAYOUT first.`
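
A minimal sketch of the failing call, assuming `tensor` is a `::ttnn::Tensor`
already in `TILE_LAYOUT` with logical shape 1x12x32; the `std::array` shape
argument is an assumption about what `vectorToArray<Rank>` in the diff below
produces:

```cpp
// Sketch only: `tensor` is assumed to be a ::ttnn::Tensor in TILE_LAYOUT
// with logical shape 1x12x32. The requested 12x32 output has a height (12)
// that is not a multiple of the 32x32 tile, so this call throws the
// "Unable to reshape a tensor in TILE_LAYOUT ..." error quoted above.
std::array<int32_t, 2> newShape = {12, 32};
auto out = ::ttnn::reshape(tensor, newShape);
```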

In addition, when we did manage to run `reshape` on a tile tensor, we
found that the output tensor did not look the way we expected.
For example:

If we have an input row-major 4x4 tensor containing the numbers 0..15
and we reshape it to 2x8, we get [[0,1 .. 6,7], [8,9 .. 14,15]].
This output is correct.
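
As a plain C++ sketch (independent of TTNN) of why this output is correct: a
row-major reshape only reinterprets the same linear buffer under new
dimensions, so the element order is unchanged:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A 4x4 row-major tensor holding 0..15 is just a flat buffer of 16 values.
  std::vector<int32_t> data(16);
  for (int32_t i = 0; i < 16; ++i) data[i] = i;

  // Reshaping 4x4 -> 2x8 in row-major layout changes only the indexing math,
  // not the underlying storage order.
  const int rows = 2, cols = 8;
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      std::cout << data[r * cols + c] << ' ';
    }
    std::cout << '\n'; // prints: 0 1 .. 7, then 8 9 .. 15
  }
  return 0;
}
```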

If we take the same input tensor, convert it to tile layout,
call reshape to Shape((2, 8), (32, 32)), and convert the output tensor back to
row-major, we get a tensor that looks like this:
[[0,1,2,3,0,0,0,0],[4,5,6,7,0,0,0,0]]. This is different from the row-major
reshape.

In order to unblock forge for the `reshape` op, we convert all input
tensors to row-major before calling reshape, and then convert the output
tensor back to the original layout (see the `invoke_reshape` diff below).

Once we understand how reshape works for tile layout, we will remove
this workaround and implement a proper fix.

* Adding comment

* Moving test to Silicon. Adding negative test for dialect.
mtopalovicTT authored Sep 16, 2024
1 parent 056534a commit 8152560
Showing 6 changed files with 52 additions and 11 deletions.
9 changes: 8 additions & 1 deletion runtime/lib/ttnn/program.cpp
@@ -710,7 +710,14 @@ vectorToArray(const std::vector<int32_t> &vec) {
 template <int32_t Rank>
 static ::ttnn::Tensor invoke_reshape(const ::ttnn::Tensor &tensor,
                                      const std::vector<int32_t> &shape) {
-  return ::ttnn::reshape(tensor, vectorToArray<Rank>(shape));
+  // TODO #686 - figure out how to call reshape in tile layout
+  if (tensor.get_layout() == ::ttnn::Layout::ROW_MAJOR) {
+    return ::ttnn::reshape(tensor, vectorToArray<Rank>(shape));
+  }
+
+  auto rowMajorTensor = untilize(tensor);
+  auto res = ::ttnn::reshape(rowMajorTensor, vectorToArray<Rank>(shape));
+  return tilize(res);
 }

static void run(::tt::target::ttnn::ReshapeOp const *op,
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/reshape/reshape_fail_on_dims.mlir
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
// CHECK: error: 'ttir.reshape' op Shape attribute must match the output tensor shape for dimensions that are not -1
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<4x2x32x34xbf16>) -> tensor<2x4x32x34xbf16> {
%0 = tensor.empty() : tensor<2x4x32x34xbf16>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [3: i32, 4: i32, 32: i32, 34: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x34xbf16>, tensor<2x4x32x34xbf16>) -> tensor<2x4x32x34xbf16>
return %1 : tensor<2x4x32x34xbf16>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/reshape/simple_reshape.mlir
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<4x2x32x34xbf16>) -> tensor<2x4x32x34xbf16> {
%0 = tensor.empty() : tensor<2x4x32x34xbf16>
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
%1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 34: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x34xbf16>, tensor<2x4x32x34xbf16>) -> tensor<2x4x32x34xbf16>
return %1 : tensor<2x4x32x34xbf16>
}
}
10 changes: 0 additions & 10 deletions test/ttmlir/Dialect/TTNN/simple_reshape.mlir

This file was deleted.

12 changes: 12 additions & 0 deletions test/ttmlir/Silicon/TTNN/reshape/reshape_tile_aligned.mlir
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<4x32x2x32xbf16> {
%0 = tensor.empty() : tensor<4x32x2x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
%1 = "ttir.reshape"(%arg0, %0) <{shape = [4: i32, 32: i32, 2: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<4x32x2x32xbf16>) -> tensor<4x32x2x32xbf16>
return %1 : tensor<4x32x2x32xbf16>
}
}
12 changes: 12 additions & 0 deletions test/ttmlir/Silicon/TTNN/reshape/simple_reshape.mlir
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<4x2x32x34xbf16>) -> tensor<2x4x32x34xbf16> {
%0 = tensor.empty() : tensor<2x4x32x34xbf16>
// CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]
%1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 34: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x34xbf16>, tensor<2x4x32x34xbf16>) -> tensor<2x4x32x34xbf16>
return %1 : tensor<2x4x32x34xbf16>
}
}
