From fda5027fbf94e2f908d1cb514e4ff25807fb34ce Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Fri, 27 Dec 2024 22:59:06 +0000
Subject: [PATCH] Further removal of Shape/LegacyShape

---
 .../data_movement/pad/device/pad_op.cpp       |  45 ++-
 .../data_movement/pad/device/pad_op.hpp       |  17 +-
 .../pad/device/pad_program_factory.cpp        | 266 ++----------------
 .../pad/device/pad_program_factory.hpp        |  26 +-
 .../ttnn/operations/data_movement/pad/pad.cpp |   6 +-
 ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp | 102 ++++---
 ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp |   9 +-
 7 files changed, 132 insertions(+), 339 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp
index fc75c0ae544..fad1cfe2832 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp
@@ -29,21 +29,21 @@ void Pad::validate_with_output_tensors(
             "On device padding only supports padding at end of dims");
     }
     TT_FATAL(
-        input_tensor.get_legacy_shape()[0] + this->input_tensor_start[0] <= this->output_tensor_shape[0],
+        input_tensor.get_padded_shape()[0] + this->input_tensor_start[0] <= this->output_padded_shape[0],
         "Output size cannot fit input with offset");
     TT_FATAL(
-        input_tensor.get_legacy_shape()[1] + this->input_tensor_start[1] <= this->output_tensor_shape[1],
+        input_tensor.get_padded_shape()[1] + this->input_tensor_start[1] <= this->output_padded_shape[1],
         "Output size cannot fit input with offset");
     TT_FATAL(
-        input_tensor.get_legacy_shape()[2] + this->input_tensor_start[2] <= this->output_tensor_shape[2],
+        input_tensor.get_padded_shape()[2] + this->input_tensor_start[2] <= this->output_padded_shape[2],
         "Output size cannot fit input with offset");
     TT_FATAL(
-        input_tensor.get_legacy_shape()[3] + this->input_tensor_start[3] <= this->output_tensor_shape[3],
+        input_tensor.get_padded_shape()[3] + this->input_tensor_start[3] <= this->output_padded_shape[3],
         "Output size cannot fit input with offset");
 
     if (input_tensor.get_layout() == Layout::TILE) {
-        TT_FATAL((this->output_tensor_shape[2] % TILE_HEIGHT == 0), "Can only pad tilized tensor with full tiles");
-        TT_FATAL((this->output_tensor_shape[3] % TILE_WIDTH == 0), "Can only pad tilized tensor with full tiles");
+        TT_FATAL((this->output_padded_shape[2] % TILE_HEIGHT == 0), "Can only pad tilized tensor with full tiles");
+        TT_FATAL((this->output_padded_shape[3] % TILE_WIDTH == 0), "Can only pad tilized tensor with full tiles");
         TT_FATAL(
             input_tensor.get_dtype() == DataType::FLOAT32 || input_tensor.get_dtype() == DataType::BFLOAT16,
             "Cannot pad tilized tensor with specified format");
@@ -62,19 +62,16 @@ void Pad::validate_with_output_tensors(
     }
 }
 
-std::vector<ttnn::SimpleShape> Pad::compute_output_shapes(const std::vector<Tensor>&) const {
-    return {this->output_tensor_shape.logical_shape()};
-}
-
-std::vector<Tensor> Pad::create_output_tensors(
-    const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
+std::vector<ttnn::TensorSpec> Pad::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    return {create_device_tensor(
-        output_tensor_shape,
-        input_tensor.get_dtype(),
-        input_tensor.get_layout(),
-        input_tensor.device(),
-        this->output_mem_config)};
+    return {TensorSpec(
+        output_logical_shape,
+        TensorLayout::fromPaddedShape(
+            input_tensor.get_dtype(),
+            PageConfig(input_tensor.get_layout()),
+            output_mem_config,
+            output_logical_shape,
+            output_padded_shape))};
 }
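A note on the hunk above: the op no longer allocates its output tensor directly; it now returns a declarative TensorSpec built from the logical and padded output shapes. A minimal sketch of the new pattern follows — the shapes, dtype, and memory config are illustrative assumptions, not values from this patch:

    // Hypothetical usage of the spec-based flow introduced here (sketch only).
    ttnn::SimpleShape logical_shape({1, 1, 30, 62});  // extents the caller asked for
    ttnn::SimpleShape padded_shape({1, 1, 32, 64});   // tile-aligned extents
    TensorSpec spec(
        logical_shape,
        TensorLayout::fromPaddedShape(
            DataType::BFLOAT16,
            PageConfig(Layout::TILE),
            MemoryConfig{},  // assumed default interleaved config
            logical_shape,
            padded_shape));  // alignment is derived from logical vs. padded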
 
 operation::ProgramWithCallbacks Pad::create_program(
@@ -104,22 +101,22 @@ operation::ProgramWithCallbacks Pad::create_program(
             return {};
         } else if (input_w != output_w) {
             return detail::pad_rm_sharded_width_only(
-                input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
         } else if (input_tot_h != output_tot_h) {
             return detail::pad_rm_sharded_height_only(
-                input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
         } else {
             // for no padding, we just use the height-only padding program
             return detail::pad_rm_sharded_height_only(
-                input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
         }
     } else {
         if (use_multicore) {
             return detail::pad_rm_reader_writer_multi_core_v2(
-                input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
         } else {
             return detail::pad_rm_reader_writer(
-                input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+                input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
         }
     }
 } else if (input_tensor.get_layout() == Layout::TILE) {
@@ -128,7 +125,7 @@ operation::ProgramWithCallbacks Pad::create_program(
             tt::LogType::LogOp, "TILE layout does not have multicore implementation yet. Falling back to 1 core.");
         }
         return detail::pad_tile(
-            input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value);
+            input_tensor, output_tensor, this->output_padded_shape, this->input_tensor_start, this->pad_value);
     } else {
         TT_THROW("Unsupported layout for pad");
         return {};
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.hpp
index 77462979b45..c32c2430840 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.hpp
@@ -12,7 +12,8 @@
 namespace ttnn::operations::data_movement {
 
 struct Pad {
-    const tt::tt_metal::LegacyShape output_tensor_shape;
+    const ttnn::SimpleShape output_logical_shape;
+    const ttnn::SimpleShape output_padded_shape;
     const ttnn::SimpleShape input_tensor_start;
     const float pad_value;
     const tt::tt_metal::MemoryConfig output_mem_config;
@@ -20,16 +21,20 @@ struct Pad {
 
     void validate_with_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(
-        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     tt::tt_metal::operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
 
     static constexpr auto attribute_names = std::forward_as_tuple(
-        "output_tensor_shape", "input_tensor_start", "pad_value", "output_mem_config", "use_multicore");
+        "output_logical_shape",
+        "output_padded_shape",
+        "input_tensor_start",
+        "pad_value",
+        "output_mem_config",
+        "use_multicore");
     const auto attribute_values() const {
         return std::forward_as_tuple(
-            this->output_tensor_shape,
+            this->output_logical_shape,
+            this->output_padded_shape,
             this->input_tensor_start,
             this->pad_value,
             this->output_mem_config,
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp
index a5383efbd40..d60095c5efd 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp
@@ -20,12 +20,12 @@ namespace ttnn::operations::data_movement::detail {
 operation::ProgramWithCallbacks pad_rm_reader_writer(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value) {
     Program program{};
 
-    auto output_shape = output_tensor_shape;
+    auto output_shape = output_padded_shape;
 
     uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size();
     uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size();  // Assuming output is same datatype as input
@@ -171,232 +171,10 @@ operation::ProgramWithCallbacks pad_rm_reader_writer(
     return {std::move(program), override_runtime_args_callback};
 }
 
-operation::ProgramWithCallbacks pad_rm_opt(
-    const Tensor& a,
-    Tensor& output,
-    const Shape& output_tensor_shape,
-    const ttnn::SimpleShape& input_tensor_start,
-    const float pad_value) {
-    Program program{};
-
-    auto output_shape = output_tensor_shape;
-
-    uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size();
-    uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size();  // Assuming output is same datatype as input
-    TT_ASSERT(
-        unpadded_row_size_nbytes <= padded_row_size_nbytes, "Padded output tensor size should be >= input tensor size");
-
-    Device* device = a.device();
-    auto dst_buffer_l1 = Buffer::create(device, padded_row_size_nbytes, padded_row_size_nbytes, BufferType::L1);
-
-    // construct const buffer with the pad_value
-    uint32_t pad_value_const_buffer_size = 32;  // noc transfers in chunks of 32
-    uint32_t pad_value_const_buffer_nbytes = pad_value_const_buffer_size * a.element_size();
-    auto pad_value_const_buffer =
-        owned_buffer::create(std::vector<bfloat16>(pad_value_const_buffer_size, bfloat16(pad_value)));
-    const Tensor pad_value_const_tensor =
-        Tensor(
-            OwnedStorage{pad_value_const_buffer},
-            Shape(std::array<uint32_t, 4>{1, 1, 1, pad_value_const_buffer_size}),
-            DataType::BFLOAT16,
-            Layout::ROW_MAJOR)
-            .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1});
-    auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address();
-
-    Buffer* src0_buffer = a.buffer();
-    Buffer* dst_buffer = output.buffer();
-    TT_FATAL(dst_buffer != nullptr, "Output buffer should be allocated on device!");
-
-    bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0;
-    bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0;
-    bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(unpadded_row_size_nbytes);
-    uint32_t src_log2_stick_size =
-        src_stick_size_is_power_of_two ? (std::uint32_t)std::log2(unpadded_row_size_nbytes) : 0;
-    bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(padded_row_size_nbytes);
-    uint32_t dst_log2_stick_size =
-        dst_stick_size_is_power_of_two ? (std::uint32_t)std::log2(padded_row_size_nbytes) : 0;
-    std::vector<uint32_t> reader_ct_args = {
-        (std::uint32_t)src0_is_dram,
-        (std::uint32_t)dst_is_dram,
-        (std::uint32_t)src_stick_size_is_power_of_two,
-        (std::uint32_t)src_log2_stick_size,
-        (std::uint32_t)dst_stick_size_is_power_of_two,
-        (std::uint32_t)dst_log2_stick_size};
-
-    bfloat16 bfloat_pad_value = bfloat16(pad_value);
-    bfloat16 bfloat_zero = bfloat16(0.0f);
-    uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_zero, bfloat_pad_value});
-
-    CoreRange core({0, 0}, {0, 0});
-    KernelHandle reader_kernel_id = CreateKernel(
-        program,
-        "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/pad_dims_rm_interleaved_opt.cpp",
-        core,
-        tt::tt_metal::ReaderDataMovementConfig(reader_ct_args));
-    uint32_t padded_row_diff_size_nbytes = padded_row_size_nbytes - unpadded_row_size_nbytes;
-
-#if 0
-    {
-        tt::log_debug("src0_buffer_addr: {}", src0_buffer->address());
-        tt::log_debug("dst_buffer_addr: {}", dst_buffer->address());
-        tt::log_debug("a.shape[0]: {}", a.get_legacy_shape()[0]);
-        tt::log_debug("out.shape[0]: {}", output_shape[0]);
-        tt::log_debug("a.shape[1]: {}", a.get_legacy_shape()[1]);
-        tt::log_debug("out.shape[1]: {}", output_shape[1]);
-        tt::log_debug("a.shape[2]: {}", a.get_legacy_shape()[2]);
-        tt::log_debug("out.shape[2]: {}", output_shape[2]);
-        tt::log_debug("s.shape[3]: {}", a.get_legacy_shape()[3]);
-        tt::log_debug("out.shape[3]: {}", output_shape[3]);
-        tt::log_debug("unpadded_row_size_nbytes: {}", unpadded_row_size_nbytes);
-        tt::log_debug("padded_row_size_nbytes: {}", padded_row_size_nbytes);
-        tt::log_debug("padded_row_diff_size_nbytes: {}", padded_row_diff_size_nbytes);
-        tt::log_debug("pad_value_const_tensor_addr: {}", pad_value_const_tensor_addr);
-        tt::log_debug("pad_value_const_buffer_nbytes: {}", pad_value_const_buffer_nbytes);
-        tt::log_debug("packed_pad_value: {}", packed_pad_value);
-        tt::log_debug("dst_buffer_l1_addr: {}", dst_buffer_l1->address());
-    }
-#endif
-
-    const std::array reader_rt_args = {
-        src0_buffer->address(),
-        dst_buffer->address(),
-        a.get_legacy_shape()[0],
-        output_shape[0],
-        a.get_legacy_shape()[1],
-        output_shape[1],
-        a.get_legacy_shape()[2],
-        output_shape[2],
-        a.get_legacy_shape()[3],
-        output_shape[3],
-        unpadded_row_size_nbytes,
-        padded_row_size_nbytes,
-        padded_row_diff_size_nbytes,
-        pad_value_const_tensor_addr,
-        pad_value_const_buffer_nbytes,
-        packed_pad_value,
-        dst_buffer_l1->address()};
-    tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_rt_args);
-
-    auto override_runtime_args_callback = [kernel_id = reader_kernel_id](
-                                              const Program& program,
-                                              const std::vector<Buffer*>& input_buffers,
-                                              const std::vector<Buffer*>& output_buffers) {
-        auto src_buffer = input_buffers.at(0);
-        auto dst_buffer = output_buffers.at(0);
-        CoreCoord core = {0, 0};
-        {
-            auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, kernel_id, core);
-            runtime_args[0] = src_buffer->address();
-            runtime_args[1] = dst_buffer->address();
-        }
-    };
-
-    return {std::move(program), override_runtime_args_callback};
-}
-
-operation::ProgramWithCallbacks pad_rm(
-    const Tensor& a,
-    Tensor& output,
-    const Shape& output_tensor_shape,
-    const ttnn::SimpleShape& input_tensor_start,
-    const float pad_value) {
-    tt::tt_metal::Program program{};
-
-    CoreRange core({0, 0}, {0, 0});
-
-    // This should allocate a DRAM buffer on the device
-    tt::tt_metal::Device* device = a.device();
-
-    auto output_shape = output_tensor_shape;
-
-    tt::tt_metal::Buffer* src0_buffer = a.buffer();
-
-    uint32_t unpadded_row_size_bytes = a.get_legacy_shape()[3] * a.element_size();
-    uint32_t padded_row_size_bytes = output_shape[3] * a.element_size();
-
-    tt::tt_metal::Buffer* dst_buffer = output.buffer();
-    TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");
-
-    uint32_t src_stick_size = unpadded_row_size_bytes;
-    uint32_t dst_stick_size = padded_row_size_bytes;
-
-    uint32_t dst_buffer_size = dst_stick_size;
-
-    tt::tt_metal::InterleavedBufferConfig buff_config{
-        .device = device,
-        .size = dst_buffer_size,
-        .page_size = dst_buffer_size,
-        .buffer_type = tt::tt_metal::BufferType::L1};
-
-    auto dst_buffer_l1 = tt::tt_metal::CreateBuffer(buff_config);
-
-    bfloat16 bfloat_pad_value = bfloat16(pad_value);
-    uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value});
-
-    const std::array reader_kernel_args = {
-        src0_buffer->address(),
-        dst_buffer->address(),
-        a.get_legacy_shape()[0],
-        output_shape[0],
-        a.get_legacy_shape()[1],
-        output_shape[1],
-        a.get_legacy_shape()[2],
-        output_shape[2],
-        a.get_legacy_shape()[3],
-        output_shape[3],
-        unpadded_row_size_bytes,
-        padded_row_size_bytes,
-        padded_row_size_bytes - unpadded_row_size_bytes,
-        packed_pad_value,
-        dst_buffer_l1->address()};
-    bool src0_is_dram = src0_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
-    bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
-    bool src_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(src_stick_size);
-    uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (std::uint32_t)std::log2(src_stick_size) : 0;
-    bool dst_stick_size_is_power_of_two = tt::tt_metal::is_power_of_two_at_least_32(dst_stick_size);
-    uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (std::uint32_t)std::log2(dst_stick_size) : 0;
-    std::vector<uint32_t> compile_time_args_vec = {
-        (std::uint32_t)src0_is_dram,
-        (std::uint32_t)dst_is_dram,
-        (std::uint32_t)src_stick_size_is_power_of_two,
-        (std::uint32_t)src_log2_stick_size,
-        (std::uint32_t)dst_stick_size_is_power_of_two,
-        (std::uint32_t)dst_log2_stick_size,
-
-    };
-
-    // Tilized reader
-    tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
-        program,
-        "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/pad_dims_rm_interleaved.cpp",
-        core,
-        tt::tt_metal::ReaderDataMovementConfig(compile_time_args_vec));
-
-    tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args);
-
-    auto override_runtime_args_callback = [kernel_id = unary_reader_kernel_id](
-                                              const Program& program,
-                                              const std::vector<Buffer*>& input_buffers,
-                                              const std::vector<Buffer*>& output_buffers) {
-        auto src_buffer = input_buffers.at(0);
-        auto dst_buffer = output_buffers.at(0);
-
-        CoreCoord core = {0, 0};
-
-        {
-            auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, kernel_id, core);
-            runtime_args[0] = src_buffer->address();
-            runtime_args[1] = dst_buffer->address();
-        }
-    };
-
-    return {std::move(program), override_runtime_args_callback};
-}
-
 operation::ProgramWithCallbacks pad_tile(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value) {
     tt::tt_metal::Program program{};
@@ -406,7 +184,7 @@ operation::ProgramWithCallbacks pad_tile(
     // This should allocate a DRAM buffer on the device
     tt::tt_metal::Device* device = a.device();
 
-    auto output_shape = output_tensor_shape;
+    auto output_shape = output_padded_shape;
 
     tt::tt_metal::Buffer* src0_buffer = a.buffer();
 
@@ -419,7 +197,7 @@ operation::ProgramWithCallbacks pad_tile(
     tt::log_debug("pad_tile");
     tt::log_debug("cb_data_format: {}", cb_data_format);
     tt::log_debug("single_tile_size: {}", single_tile_size);
-    tt::log_debug("output_tensor_shape: {}", output_tensor_shape);
+    tt::log_debug("output_tensor_shape: {}", output_padded_shape);
     tt::log_debug("input_tensor_start: {}", input_tensor_start);
     tt::log_debug("pad_value: {}", pad_value);
 
@@ -670,12 +448,12 @@ split_across_cores(CoreCoord grid_size, uint32_t nbatch, uint32_t nchannel, uint
 operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value) {
     Program program{};
 
-    auto output_shape = output_tensor_shape;
+    auto output_shape = output_padded_shape;
 
     uint32_t unpadded_row_size_nbytes = a.get_legacy_shape()[3] * a.element_size();
     uint32_t padded_row_size_nbytes = output_shape[3] * a.element_size();  // Assuming output is same datatype as input
@@ -701,12 +479,12 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(
     auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address();
 
     // uint32_t ntiles_h = output_tensor_shape[0] * output_tensor_shape[1] * output_tensor_shape[2] / TILE_HEIGHT;
-    uint32_t ntiles_h = output_tensor_shape[2] / TILE_HEIGHT;
-    uint32_t ntiles_w = output_tensor_shape[3] / TILE_WIDTH;
+    uint32_t ntiles_h = output_padded_shape[2] / TILE_HEIGHT;
+    uint32_t ntiles_w = output_padded_shape[3] / TILE_WIDTH;
 
     auto grid_size = device->compute_with_storage_grid_size();
-    uint32_t nbatch = output_tensor_shape[0];
-    uint32_t nchannel = output_tensor_shape[1];
+    uint32_t nbatch = output_padded_shape[0];
+    uint32_t nchannel = output_padded_shape[1];
 
     // first the batch dim is distributed along H, and within each batch then the tiles are distributed.
     auto [ncores,
@@ -1013,16 +791,16 @@ std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_runtime
 operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value) {
     Program program{};
 
-    auto output_shape = output_tensor_shape;
+    auto output_shape = output_padded_shape;
 
     uint32_t W = a.shape()[3], H = a.shape()[2], C = a.shape()[1], N = a.shape()[0];
     uint32_t NCH = H * C * N;
-    uint32_t W_padded = output_tensor_shape[3], H_padded = output_tensor_shape[2], C_padded = output_tensor_shape[1],
-             N_padded = output_tensor_shape[0];
+    uint32_t W_padded = output_padded_shape[3], H_padded = output_padded_shape[2], C_padded = output_padded_shape[1],
+             N_padded = output_padded_shape[0];
     uint32_t NCH_padded = H_padded * C_padded * N_padded;
 
     auto& front_pad = input_tensor_start;
@@ -1402,16 +1180,16 @@ inline std::vector<std::pair<std::vector<uint32_t>, std::vector<uint32_t>>> get_
 operation::ProgramWithCallbacks pad_rm_sharded_height_only(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value) {
     Program program{};
 
-    auto output_shape = output_tensor_shape;
+    auto output_shape = output_padded_shape;
 
     uint32_t W = a.shape()[3], H = a.shape()[2], C = a.shape()[1], N = a.shape()[0];
     uint32_t num_unpadded_sticks = H * C * N;
-    uint32_t W_padded = output_tensor_shape[3], H_padded = output_tensor_shape[2], C_padded = output_tensor_shape[1],
-             N_padded = output_tensor_shape[0];
+    uint32_t W_padded = output_padded_shape[3], H_padded = output_padded_shape[2], C_padded = output_padded_shape[1],
+             N_padded = output_padded_shape[0];
     uint32_t num_padded_sticks = H_padded * C_padded * N_padded;
 
     auto& front_pad = input_tensor_start;
@@ -1580,18 +1358,18 @@ operation::ProgramWithCallbacks pad_rm_sharded_height_only(
 operation::ProgramWithCallbacks pad_rm_sharded_width_only(
     const Tensor& input_tensor,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     float pad_value) {
     Program program{};
 
     TT_ASSERT(
-        output.shard_spec().has_value() and output.shard_spec()->shape[1] == output_tensor_shape[-1],
+        output.shard_spec().has_value() and output.shard_spec()->shape[1] == output_padded_shape[-1],
         "ttnn.pad: pad_rm_sharded_width_only expects sharded output parameter with shard width equal to the width of "
         "the requested output tensor. Ensure pad_impl is calling this program factory correctly.");
 
     uint32_t W = input_tensor.logical_shape()[-1];
-    uint32_t W_padded = output_tensor_shape[3];
+    uint32_t W_padded = output_padded_shape[3];
 
     auto unpadded_stick_bytes = W * input_tensor.element_size();
     auto padded_stick_bytes = W_padded * input_tensor.element_size();
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp
index e0bc3ba78bf..15f7778fd0b 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp
@@ -9,56 +9,42 @@ namespace ttnn::operations::data_movement::detail {
 tt::tt_metal::operation::ProgramWithCallbacks pad_rm_reader_writer(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
-    const ttnn::SimpleShape& input_tensor_start,
-    const float pad_value);
-
-tt::tt_metal::operation::ProgramWithCallbacks pad_rm_opt(
-    const Tensor& a,
-    Tensor& output,
-    const Shape& output_tensor_shape,
-    const ttnn::SimpleShape& input_tensor_start,
-    const float pad_value);
-
-tt::tt_metal::operation::ProgramWithCallbacks pad_rm(
-    const Tensor& a,
-    Tensor& output,
-    const Shape& output_tensor_shape,
+    const ttnn::SimpleShape& output_logical_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value);
 
 tt::tt_metal::operation::ProgramWithCallbacks pad_tile(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value);
 
 tt::tt_metal::operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value);
 
 tt::tt_metal::operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value);
 
 tt::tt_metal::operation::ProgramWithCallbacks pad_rm_sharded_height_only(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     const float pad_value);
 
 tt::tt_metal::operation::ProgramWithCallbacks pad_rm_sharded_width_only(
     const Tensor& a,
     Tensor& output,
-    const tt::tt_metal::LegacyShape& output_tensor_shape,
+    const ttnn::SimpleShape& output_padded_shape,
     const ttnn::SimpleShape& input_tensor_start,
     float pad_value);
diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
index f54a763b638..7edb772d48b 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
@@ -128,10 +128,10 @@ static ttnn::Tensor pad_impl(
         !input_tensor.is_sharded() || output_w == output_memory_config.shard_spec->shape[1],
         "output_w != output_memory_config.shard_spec().shape[1]");
 
-    tt::tt_metal::LegacyShape output_padded_legacy_shape{output_padded_shape};
-
+    ttnn::SimpleShape output_shape{output_padded_shape};
     auto output_tensor = operation::run(
-        Pad{output_padded_legacy_shape,
+        Pad{output_shape,
+            output_shape,
             ttnn::SimpleShape{input_tensor_start},
             value,
             output_memory_config,
diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
index 339c919571a..dfd36a108b1 100644
--- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
+++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp
@@ -4,6 +4,8 @@
 
 #include "tensor_layout.hpp"
 
+#include "ttnn/tensor/tensor_utils.hpp"
+
 namespace tt::tt_metal {
 
 namespace {
@@ -18,25 +20,27 @@ size_t round_up(size_t value, size_t multiple) {
 };
 
 Alignment legacyShapeToAlignment(
-    const ttnn::Shape& shape, const PageConfig& page_config, const MemoryConfig& memory_config) {
-    const auto& logical_shape = shape.logical_shape();
-    const auto& legacy_padded_shape = shape.padded_shape();
-    if (logical_shape == legacy_padded_shape) {
+    const ttnn::SimpleShape& logical_shape,
+    const ttnn::SimpleShape& padded_shape,
+    const PageConfig& page_config,
+    const MemoryConfig& memory_config) {
+    if (logical_shape == padded_shape) {
         return Alignment{};
     }
 
-    const auto rank = legacy_padded_shape.rank();
+    const auto rank = padded_shape.rank();
     bool alignment_can_be_2D = true;
     for (int i = rank - 3; i >= 0; i--) {
-        alignment_can_be_2D &= logical_shape[i] == legacy_padded_shape[i];
+        alignment_can_be_2D &= logical_shape[i] == padded_shape[i];
     }
 
     // SHARDED
     if (memory_config.shard_spec.has_value()) {
         TT_FATAL(
             alignment_can_be_2D,
-            "Tensor with shape {} cannot be sharded because alignment will have rank greater than 2!",
-            shape);
+            "Tensor with shape {} ({}) cannot be sharded because alignment will have rank greater than 2!",
+            logical_shape,
+            padded_shape);
         if (page_config.get_layout() == Layout::ROW_MAJOR) {
             const auto& shard_spec = memory_config.shard_spec.value();
             if (shard_spec.physical_shard_shape.has_value()) {
@@ -52,10 +56,10 @@ Alignment legacyShapeToAlignment(
         ttnn::SmallVector<uint32_t> values(std::min((int)rank, 2));
         const auto alignment_size = values.size();
         if (alignment_size >= 1) {
-            values[alignment_size - 1] = legacy_padded_shape[-1];
+            values[alignment_size - 1] = padded_shape[-1];
         }
         if (alignment_size == 2) {
-            values[alignment_size - 2] = legacy_padded_shape[-2];
+            values[alignment_size - 2] = padded_shape[-2];
         }
         Alignment result(std::move(values));
         return result;
@@ -64,11 +68,11 @@ Alignment legacyShapeToAlignment(
     // INTERLEAVED with (deprecated) non-height/width padding
     // NOTE: Rank > 2 is guaranteed in this case
     ttnn::SmallVector<uint32_t> values(rank);
-    values[rank - 1] = legacy_padded_shape[-1];
-    values[rank - 2] = legacy_padded_shape[-2];
+    values[rank - 1] = padded_shape[-1];
+    values[rank - 2] = padded_shape[-2];
 
     for (int i = rank - 3; i >= 0; i--) {
-        values[i] = legacy_padded_shape[i] * values[i + 1];
+        values[i] = padded_shape[i] * values[i + 1];
     }
 
     for (auto& value : values) {
@@ -101,15 +105,40 @@ TensorLayout TensorLayout::fromLegacyPaddedShape(
         dtype,
         page_config,
         memory_config,
-        CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(legacy_shape, page_config, memory_config));
+        CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(
+            legacy_shape.logical_shape(), legacy_shape.padded_shape(), page_config, memory_config));
+}
+
+TensorLayout TensorLayout::fromPaddedShape(
+    DataType dtype,
+    const PageConfig& page_config,
+    const MemoryConfig& memory_config,
+    const ttnn::SimpleShape& logical_shape,
+    const ttnn::SimpleShape& padded_shape) {
+    return TensorLayout(
+        dtype,
+        page_config,
+        memory_config,
+        CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(logical_shape, padded_shape, page_config, memory_config));
 }
 
 void TensorLayout::initialize_alignment() {
-    if (!alignment_.empty()) {
+    auto default_alignment = page_config_.create_default_alignment(dtype_, memory_config_);
+    if (alignment_.empty()) {
+        alignment_ = default_alignment;
         return;
     }
 
-    alignment_ = page_config_.create_default_alignment(dtype_, memory_config_);
+    ttnn::SmallVector<uint32_t> result(std::max(alignment_.size(), default_alignment.size()), 1);
+    for (size_t i = 0; i < alignment_.size(); i++) {
+        result[i + result.size() - alignment_.size()] = alignment_[i];
+    }
+    for (size_t i = 0; i < default_alignment.size(); i++) {
+        size_t result_idx = i + result.size() - default_alignment.size();
+        result[result_idx] = CMAKE_UNIQUE_NAMESPACE::round_up(result[result_idx], default_alignment[i]);
+    }
+
+    alignment_ = Alignment(std::move(result));
 }
 
 void TensorLayout::validate_alignment() const {
@@ -310,39 +339,30 @@ Size TensorLayout::compute_page_shape(const Size& physical_size) const {
 }
 
 Strides TensorLayout::compute_strides(const ttnn::SimpleShape& shape) const {
-    const int rank = static_cast<int>(shape.rank());
-    const int alignment_rank = static_cast<int>(alignment_.size());
-
-    Strides strides(rank, 1);
-    for (int i = rank - 2; i >= 0; i--) {
-        strides[i] = strides[i + 1] * shape[i + 1];
-
-        const int alignment_index = i - (rank - alignment_rank) + 1;
-        if (alignment_index >= 0) {
-            strides[i] = CMAKE_UNIQUE_NAMESPACE::round_up(strides[i], alignment_[alignment_index]);
-        }
-    }
-
-    return strides;
+    auto padded_shape = compute_padded_shape(shape);
+    return tt::tt_metal::compute_strides(padded_shape);
 }
 
 ttnn::SimpleShape TensorLayout::compute_padded_shape(const ttnn::SimpleShape& shape) const {
-    ttnn::SmallVector<uint32_t> padded_shape(shape.rank());
+    ttnn::SmallVector<uint32_t> padded_shape(std::max(shape.rank(), alignment_.size()));
 
     int rank_index = static_cast<int>(shape.rank()) - 1;
     int alignment_index = static_cast<int>(alignment_.size()) - 1;
+    int padded_shape_index = static_cast<int>(padded_shape.size() - 1);
     size_t accum_alignment = 1;
 
-    for (; rank_index >= 0 && alignment_index >= 0; rank_index--, alignment_index--) {
+    for (; alignment_index >= 0; rank_index--, alignment_index--, padded_shape_index--) {
+        uint32_t shape_value = rank_index >= 0 ? shape[rank_index] : 1;
+        uint32_t alignment_value = alignment_[alignment_index];
+        uint32_t& padded_shape_value = padded_shape[padded_shape_index];
         // The last 2 dimensions of a shape are special
         if (rank_index >= static_cast<int>(shape.rank()) - 2) {
-            padded_shape[rank_index] = CMAKE_UNIQUE_NAMESPACE::round_up(shape[rank_index], alignment_[alignment_index]);
+            padded_shape_value = CMAKE_UNIQUE_NAMESPACE::round_up(shape_value, alignment_value);
         } else {
-            if (accum_alignment % alignment_[alignment_index] == 0) {
+            if (accum_alignment % alignment_value == 0) {
                 // Alignment for this dimension is redundant, ignoring
-                padded_shape[rank_index] = shape[rank_index];
-            } else if (alignment_[alignment_index] % accum_alignment == 0) {
-                padded_shape[rank_index] =
-                    CMAKE_UNIQUE_NAMESPACE::round_up(shape[rank_index], alignment_[alignment_index] / accum_alignment);
+                padded_shape_value = shape_value;
+            } else if (alignment_value % accum_alignment == 0) {
+                padded_shape_value = CMAKE_UNIQUE_NAMESPACE::round_up(shape_value, alignment_value / accum_alignment);
             } else {
                 TT_THROW(
                     "Padded shape can't be deducted from TensorLayout parameters {} and Shape {}", alignment_, shape);
@@ -351,11 +371,11 @@ ttnn::SimpleShape TensorLayout::compute_padded_shape(const ttnn::SimpleShape& sh
 
         // Alignment doesn't accumulate on the last dimension of a shape
         if (rank_index != static_cast<int>(shape.rank()) - 1) {
-            accum_alignment *= padded_shape[rank_index];
+            accum_alignment *= padded_shape_value;
         }
     }
-    for (; rank_index >= 0; rank_index--) {
-        padded_shape[rank_index] = shape[rank_index];
+    for (; rank_index >= 0; rank_index--, padded_shape_index--) {
+        padded_shape[padded_shape_index] = shape[rank_index];
     }
 
     return ttnn::SimpleShape(std::move(padded_shape));
 }
diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
index 2e9b24cb03a..6625bb19ac6 100644
--- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
+++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp
@@ -14,7 +14,7 @@
 
 namespace tt::tt_metal {
 
-using Strides = std::vector<size_t>;
+using Strides = ttnn::SmallVector<size_t>;
 
 // TensorLayout describes how a tensor is laid out in memory
 // It takes datatype, layout (eg. TILE vs. RM), memory (eg. DRAM vs. L1), sharding (ie. how you want to cut your logical
@@ -31,6 +31,13 @@ class TensorLayout {
         const PageConfig& page_config,
         const MemoryConfig& memory_config,
         const ttnn::Shape& legacy_shape);
+    [[deprecated("Use of Padded Shape is deprecated")]]
+    static TensorLayout fromPaddedShape(
+        DataType dtype,
+        const PageConfig& page_config,
+        const MemoryConfig& memory_config,
+        const ttnn::SimpleShape& logical_shape,
+        const ttnn::SimpleShape& padded_shape);
 
     Layout get_layout() const { return page_config_.get_layout(); }
     PageConfig get_page_config() const { return page_config_; }
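A closing note on compute_padded_shape: each dimension covered by the alignment is rounded up to its alignment value (the last two dimensions directly, outer dimensions through the accumulated product). A small sketch of that rounding rule — the helper below mirrors the file's round_up, and the example values are ours, not taken from the patch:

    #include <cstddef>

    // Round value up to the nearest multiple (sketch of the rule used above;
    // the zero-multiple guard is our assumption, not copied from the file).
    size_t round_up(size_t value, size_t multiple) {
        return multiple == 0 ? value : ((value + multiple - 1) / multiple) * multiple;
    }

    // With the default TILE alignment (TILE_HEIGHT, TILE_WIDTH) = (32, 32),
    // a logical shape of (1, 1, 30, 62) pads to (1, 1, 32, 64):
    //   round_up(30, 32) == 32, round_up(62, 32) == 64.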