-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#14372: fix transpose hc RM pcc errors + add N-dimensional permute single core implementation (#14388)
#14372: fix transpose hc RM pcc errors - Restriction was on alignment, which was hard coded to 16 for L1, which caused PCC issues when read from DRAM
#14370: add basic N-d permute code and support both N-d transpose and permute - TODO: make multicore, single-core for now
- Loading branch information
Showing
12 changed files
with
504 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
...perations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
constexpr bool src0_is_dram = (bool) get_compile_time_arg_val(0); | ||
constexpr uint32_t N = get_compile_time_arg_val(1); | ||
constexpr uint32_t page_size = get_compile_time_arg_val(2); | ||
constexpr uint32_t num_rows = get_compile_time_arg_val(3); | ||
|
||
const uint32_t src_addr = get_arg_val<uint32_t>(0); | ||
|
||
const InterleavedAddrGen<src0_is_dram> s0 = { | ||
.bank_base_address = src_addr, | ||
.page_size = page_size | ||
}; | ||
|
||
uint32_t curr_addr = src_addr; | ||
for (uint32_t i = 0; i < num_rows; ++i) { | ||
cb_reserve_back(tt::CB::c_in0, 1); | ||
uint32_t src_buffer_l1_addr = get_write_ptr(tt::CB::c_in0); | ||
noc_async_read_page(i, s0, src_buffer_l1_addr); | ||
noc_async_read_barrier(); | ||
volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(src_buffer_l1_addr); | ||
cb_push_back(tt::CB::c_in0, 1); | ||
} | ||
|
||
} |
61 changes: 61 additions & 0 deletions
61
...perations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
#include "dataflow_api.h" | ||
|
||
|
||
void kernel_main() { | ||
constexpr bool dst_is_dram = (bool) get_compile_time_arg_val(0); | ||
constexpr uint32_t N = get_compile_time_arg_val(1); | ||
constexpr uint32_t page_size = get_compile_time_arg_val(2); | ||
constexpr uint32_t num_rows = get_compile_time_arg_val(3); | ||
|
||
const uint32_t dst_addr = get_arg_val<uint32_t>(0); | ||
|
||
const InterleavedAddrGen<dst_is_dram> s0 = { | ||
.bank_base_address = dst_addr, | ||
.page_size = page_size | ||
}; | ||
|
||
uint32_t input_shape[N], perm[N], dest_strides[N]; | ||
for (uint32_t i = 1; i <= N; i++) { | ||
input_shape[i - 1] = get_arg_val<uint32_t>(i); | ||
perm[i - 1] = get_arg_val<uint32_t>(i + N); | ||
dest_strides[i - 1] = get_arg_val<uint32_t>(i + 2*N); | ||
} | ||
|
||
uint32_t src_buffer_l1_addr = get_write_ptr(tt::CB::c_in0); | ||
uint32_t curr_addr = dst_addr; | ||
for (uint32_t row = 0; row < num_rows; ++row) { | ||
// Compute multi-dimensional index for the source row | ||
uint32_t src_multi_idx[N]; | ||
size_t remaining = row; | ||
for(uint32_t i = 0; i < N - 1; ++i) { | ||
size_t dim = N - 2 - i; // Start from the second last dimension | ||
src_multi_idx[dim] = remaining % input_shape[dim]; | ||
remaining /= input_shape[dim]; | ||
} | ||
src_multi_idx[N - 1] = 0; // Row dimension index | ||
|
||
// Apply permutation to get destination multi-dimensional index | ||
uint32_t dest_multi_idx[N]; | ||
for(uint32_t i = 0; i < N; ++i) { | ||
dest_multi_idx[i] = src_multi_idx[perm[i]]; | ||
} | ||
|
||
// Convert destination multi-dimensional index to linear index | ||
uint32_t dest_linear_idx = 0; | ||
for(uint32_t i = 0; i < N - 1; ++i) { | ||
dest_linear_idx += dest_multi_idx[i] * dest_strides[i]; | ||
} | ||
cb_wait_front(tt::CB::c_in0, 1); | ||
uint32_t l1_read_addr = get_read_ptr(tt::CB::c_in0); | ||
uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0); | ||
noc_async_write(l1_read_addr, dst_noc_addr, page_size); | ||
noc_async_write_barrier(); | ||
cb_pop_front(tt::CB::c_in0, 1); | ||
} | ||
|
||
} |
68 changes: 68 additions & 0 deletions
68
ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstdint> | ||
|
||
#include "ttnn/cpp/ttnn/tensor/types.hpp" | ||
#include "ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" | ||
|
||
namespace ttnn::operations::data_movement { | ||
|
||
// Chooses the program factory used to run the permute.
// Neither argument is consulted: the single-core factory is always returned.
PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program_factory(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    program_factory_t selected_factory = SingleCore{};
    return selected_factory;
}
|
||
void PermuteDeviceOperation::validate_on_program_cache_miss( | ||
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { | ||
TT_FATAL(attributes.dims.size() == tensor_args.input_tensor.get_logical_shape().rank(), | ||
"Permute dimensions must match input tensor rank"); | ||
TT_FATAL(attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1, | ||
"Last dimension of permute must be the last dimension of the input tensor as page-breaking is not supported at the moment"); | ||
TT_FATAL(tensor_args.input_tensor.is_sharded() == false, | ||
"Permute operation does not support sharded input tensor"); | ||
TT_FATAL(tensor_args.input_tensor.get_layout() == Layout::ROW_MAJOR, "Permute operation only supports row-major layout"); | ||
} | ||
|
||
void PermuteDeviceOperation::validate_on_program_cache_hit( | ||
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {} | ||
|
||
// Computes the logical shape of the permuted tensor.
//
// Output dimension i has the extent of input dimension `attributes.dims[i]`.
// The entries of `dims` are used as indices without further checking here.
PermuteDeviceOperation::shape_return_value_t PermuteDeviceOperation::compute_output_shapes(
    const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
    const auto& input_shape = tensor_args.input_tensor.get_logical_shape();

    SmallVector<uint32_t> shape;
    shape.reserve(input_shape.rank());
    for (auto dim : attributes.dims) {
        shape.push_back(input_shape[dim]);
    }
    return ttnn::SimpleShape(shape);
}
|
||
// Produces the tensor the permute writes into.
//
// If the caller supplied an output tensor, it is returned as is. Otherwise a
// new device tensor is created with the permuted shape, the input's dtype and
// layout, on the input's device, using the memory config stored in the
// operation attributes (the caller's requested config, or the input's config
// when none was requested).
PermuteDeviceOperation::tensor_return_value_t PermuteDeviceOperation::create_output_tensors(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    if (tensor_args.optional_output_tensor.has_value()) {
        return tensor_args.optional_output_tensor.value();
    }

    const auto& input_tensor = tensor_args.input_tensor;
    auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
    // Pass the resolved memory config explicitly; without it the output would
    // fall back to create_device_tensor's default and ignore the request.
    return create_device_tensor(
        output_shape,
        input_tensor.tensor_attributes->dtype,
        input_tensor.tensor_attributes->layout,
        input_tensor.device(),
        operation_attributes.output_mem_config);
}
|
||
|
||
std::tuple<PermuteDeviceOperation::operation_attributes_t, PermuteDeviceOperation::tensor_args_t> | ||
PermuteDeviceOperation::invoke(const Tensor& input_tensor, const SmallVector<uint32_t>& dims, | ||
const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor) { | ||
return { | ||
operation_attributes_t{.dims=dims, | ||
.output_mem_config=memory_config.value_or(input_tensor.memory_config())}, | ||
tensor_args_t{.input_tensor=input_tensor, .optional_output_tensor=optional_output_tensor} | ||
}; | ||
} | ||
|
||
|
||
} // namespace ttnn::operations::data_movement |
Oops, something went wrong.