diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp
index e95ede88e65..4dba8314615 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.cpp
@@ -44,7 +44,7 @@ std::vector<Tensor> InterleavedToShardedDeviceOperation::create_output_tensors(c
 operation::ProgramWithCallbacks InterleavedToShardedDeviceOperation::create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
     auto& output_tensor = output_tensors.at(0);
-    return detail::interleaved_to_sharded_multi_core(input_tensor, output_tensor);
+    return detail::interleaved_to_sharded_multi_core(input_tensor, output_tensor, this->keep_l1_aligned);
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.hpp
index 867d9263a76..8c21018b07d 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_op.hpp
@@ -13,6 +13,7 @@ namespace ttnn::operations::data_movement {
 struct InterleavedToShardedDeviceOperation {
     const tt::tt_metal::MemoryConfig output_mem_config;
     const tt::tt_metal::DataType output_dtype;
+    const bool keep_l1_aligned = false;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
     std::vector compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp
index 5284220b9cc..8bd29eb8656 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp
@@ -15,7 +15,7 @@ using namespace tt::tt_metal;
 namespace ttnn::operations::data_movement::detail {
 
 operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
-    const Tensor& input, const Tensor& output, uint32_t num_slices, uint32_t slice_index) {
+    const Tensor& input, const Tensor& output, bool keep_l1_aligned, uint32_t num_slices, uint32_t slice_index) {
     tt::tt_metal::Program program{};
 
     uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width,
@@ -71,7 +71,13 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
         // TODO: Use a different variable name. Units refers to pages, but this is being used as size
         num_units_per_shard_width_last =
             input_unit_size - (tt::round_up(num_units_per_row, input_unit_size) - num_units_per_row);
-        padded_offset_bytes = align(input_unit_size, input.buffer()->alignment());
+        // Adjust according to L1 alignment; do this for all archs
+        if (keep_l1_aligned) {
+            padded_offset_bytes = align(input_unit_size, hal.get_alignment(HalMemType::L1));
+        }
+        else {
+            padded_offset_bytes = align(input_unit_size, input.buffer()->alignment());
+        }
     }
@@ -95,7 +101,7 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
             .set_globally_allocated_address(*output.buffer());
     auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, output_cb_out_config);
     uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
-    if (src_is_dram && input_unit_size % dram_alignment != 0 or is_blackhole) {
+    if (src_is_dram && input_unit_size % dram_alignment != 0 or is_blackhole or keep_l1_aligned) {
         uint32_t scratch_cb_page_size;
         //scratchpad going to be used to align DRAM (64B) to L1 (16B)
         if (is_blackhole) {
@@ -246,7 +252,8 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
             uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
             uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
             bool aligned = (src_is_dram ? curr_idx_w % dram_alignment == 0 : true);
-            aligned = aligned and !(is_blackhole);
+            // For Blackhole and keep_l1_aligned cases, always force the unaligned kernel call
+            aligned = aligned and !(is_blackhole) and !(keep_l1_aligned);
             uint32_t aligned_width_offset, aligned_shard_width, aligned_offset;
             if (!aligned) {
                 //TODO: is this right, leaving non BH case the same for now, should investigate
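The kernel-facing change above is the `padded_offset_bytes` computation: by default each row-major page is padded out to the input buffer's alignment (64 B when the input sits in DRAM), while `keep_l1_aligned` pads only to the L1 alignment (16 B). The sketch below is a standalone illustration of that difference; `align_up`, the hard-coded alignment constants, and the 40-byte page are stand-ins for `tt::align`, `hal.get_alignment(...)`, and a real page size, not the tt-metal APIs themselves.

```cpp
// Standalone sketch (not the tt-metal HAL): how padded_offset_bytes differs when a page
// is padded to the DRAM buffer alignment vs. forced to L1 alignment. The 64 B / 16 B
// values mirror the "align DRAM (64B) to L1 (16B)" comment in the diff above.
#include <cstdint>
#include <iostream>

constexpr uint32_t kDramAlignment = 64;  // assumed DRAM alignment, per the comment in the diff
constexpr uint32_t kL1Alignment = 16;    // assumed L1 alignment

// Round `value` up to the nearest multiple of `alignment`.
constexpr uint32_t align_up(uint32_t value, uint32_t alignment) {
    return ((value + alignment - 1) / alignment) * alignment;
}

int main() {
    // Hypothetical row-major page of 40 bytes (e.g. 20 bfloat16 elements per row).
    uint32_t input_unit_size = 40;

    uint32_t dram_padded = align_up(input_unit_size, kDramAlignment);  // 64: 24 B of padding per page
    uint32_t l1_padded = align_up(input_unit_size, kL1Alignment);      // 48: only 8 B of padding per page

    std::cout << "padded_offset_bytes (buffer/DRAM aligned): " << dram_padded << "\n";
    std::cout << "padded_offset_bytes (keep_l1_aligned):     " << l1_padded << "\n";
    return 0;
}
```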
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.hpp
index a6eef5de40e..7d3b71af80c 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.hpp
@@ -9,6 +9,6 @@ namespace ttnn::operations::data_movement::detail {
 
-operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(const Tensor &a, const Tensor &output, uint32_t num_slices = 1, uint32_t slice_index = 0);
+operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(const Tensor &a, const Tensor &output, bool keep_l1_aligned = false, uint32_t num_slices = 1, uint32_t slice_index = 0);
 
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.cpp
index eada8a3e337..50c0d528164 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.cpp
@@ -16,11 +16,13 @@ ttnn::Tensor InterleavedToShardedOperation::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
     const MemoryConfig& sharded_memory_config,
-    const std::optional<DataType>& data_type_arg) {
+    const std::optional<DataType>& data_type_arg,
+    const std::optional<bool>& keep_l1_aligned) {
     return operation::run(
                InterleavedToShardedDeviceOperation{
                    .output_mem_config = sharded_memory_config,
-                   .output_dtype = data_type_arg.value_or(input_tensor.get_dtype())},
+                   .output_dtype = data_type_arg.value_or(input_tensor.get_dtype()),
+                   .keep_l1_aligned = keep_l1_aligned.value_or(false)},
                {input_tensor})
         .at(0);
 }
@@ -32,7 +34,8 @@ ttnn::Tensor InterleavedToShardedOperation::invoke(
     const std::array<uint32_t, 2> shard_shape,
     const TensorMemoryLayout shard_scheme,
     const ShardOrientation shard_orientation,
-    const std::optional<DataType>& data_type_arg) {
+    const std::optional<DataType>& data_type_arg,
+    const std::optional<bool>& keep_l1_aligned) {
     bool row_wise = shard_orientation == ShardOrientation::ROW_MAJOR;
     CoreCoord grid_size;
     CoreRangeSet grid_set;
@@ -69,7 +72,8 @@ ttnn::Tensor InterleavedToShardedOperation::invoke(
     return operation::run(
                InterleavedToShardedDeviceOperation{
                    .output_mem_config = sharded_mem_config,
-                   .output_dtype = data_type_arg.value_or(input_tensor.get_dtype())},
+                   .output_dtype = data_type_arg.value_or(input_tensor.get_dtype()),
+                   .keep_l1_aligned = keep_l1_aligned.value_or(false)},
                {input_tensor})
         .at(0);
 }
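Host-side, the new argument is threaded through as a `std::optional<bool>` that collapses to a concrete member on the device-operation struct. A minimal sketch of that pattern (using hypothetical stand-in names, not the ttnn types) showing why existing call sites keep their old behaviour:

```cpp
// Sketch of the host-side plumbing above: an optional<bool> argument defaults to
// std::nullopt so existing call sites compile unchanged, and is collapsed to a
// concrete bool on the device-operation struct with value_or(false).
#include <iostream>
#include <optional>

struct MockDeviceOperation {
    bool keep_l1_aligned = false;  // concrete flag, like on InterleavedToShardedDeviceOperation
};

MockDeviceOperation make_op(const std::optional<bool>& keep_l1_aligned = std::nullopt) {
    return MockDeviceOperation{.keep_l1_aligned = keep_l1_aligned.value_or(false)};
}

int main() {
    std::cout << make_op().keep_l1_aligned << "\n";      // 0: old call sites keep the old behaviour
    std::cout << make_op(true).keep_l1_aligned << "\n";  // 1: new callers can opt in
    return 0;
}
```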
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp
index 8ef01ebd29f..e93b57ccc00 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp
@@ -15,7 +15,8 @@ struct InterleavedToShardedOperation {
         uint8_t queue_id,
         const ttnn::Tensor& input_tensor,
         const MemoryConfig& sharded_memory_config,
-        const std::optional<DataType>& data_type_arg);
+        const std::optional<DataType>& data_type_arg,
+        const std::optional<bool>& keep_l1_aligned = std::nullopt);
     static ttnn::Tensor invoke(
         uint8_t queue_id,
         const ttnn::Tensor& input_tensor,
@@ -23,7 +24,8 @@ struct InterleavedToShardedOperation {
         const std::array<uint32_t, 2> shard_shape,
         const TensorMemoryLayout shard_scheme,
         const ShardOrientation shard_orientation,
-        const std::optional<DataType>& data_type_arg);
+        const std::optional<DataType>& data_type_arg,
+        const std::optional<bool>& keep_l1_aligned = std::nullopt);
 };
 }  // namespace operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.cpp
index 6a2a3926035..0480306f7e8 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.cpp
@@ -29,8 +29,17 @@ void bind_interleaved_to_sharded(
                tt::tt_metal::TensorMemoryLayout shard_scheme,
                tt::tt_metal::ShardOrientation shard_orientation,
                const std::optional<DataType>& output_dtype,
-               uint8_t queue_id) -> ttnn::Tensor {
-                return self(queue_id, input_tensor, grid, shard_shape, shard_scheme, shard_orientation, output_dtype);
+               uint8_t queue_id,
+               const std::optional<bool>& keep_l1_aligned) -> ttnn::Tensor {
+                return self(
+                    queue_id,
+                    input_tensor,
+                    grid,
+                    shard_shape,
+                    shard_scheme,
+                    shard_orientation,
+                    output_dtype,
+                    keep_l1_aligned);
             },
             py::arg("input_tensor").noconvert(),
             py::arg("grid"),
@@ -40,6 +49,7 @@ void bind_interleaved_to_sharded(
             py::arg("output_dtype") = std::nullopt,
             py::kw_only(),
             py::arg("queue_id") = 0,
+            py::arg("keep_l1_aligned") = false,
         },
         ttnn::pybind_overload_t{
@@ -47,14 +57,16 @@
                const ttnn::Tensor& input_tensor,
                const MemoryConfig& sharded_memory_config,
                const std::optional<DataType>& output_dtype,
-               uint8_t queue_id) -> ttnn::Tensor {
-                return self(queue_id, input_tensor, sharded_memory_config, output_dtype);
+               uint8_t queue_id,
+               const std::optional<bool>& keep_l1_aligned) -> ttnn::Tensor {
+                return self(queue_id, input_tensor, sharded_memory_config, output_dtype, keep_l1_aligned);
             },
             py::arg("input_tensor").noconvert(),
             py::arg("sharded_memory_config"),
             py::arg("output_dtype") = std::nullopt,
             py::kw_only(),
             py::arg("queue_id") = 0,
+            py::arg("keep_l1_aligned") = false,
         });
 }
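On the Python side the flag surfaces as a keyword-only argument defaulting to `false`, mirroring how `queue_id` is exposed. Below is a toy pybind11 module using the same `py::kw_only()` / defaulted-`py::arg` pattern; the module, function, and parameter names are made up for illustration and are not the real ttnn bindings.

```cpp
// Toy pybind11 module showing the binding pattern used above: the new flag becomes a
// keyword-only Python argument with a default, so existing Python call sites keep working.
#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Stand-in for the real operation; just reports whether the flag was set.
int to_sharded_stub(int tensor_id, const std::optional<bool>& keep_l1_aligned) {
    (void)tensor_id;
    return keep_l1_aligned.value_or(false) ? 1 : 0;
}

PYBIND11_MODULE(sharded_stub, m) {
    m.def(
        "interleaved_to_sharded",
        &to_sharded_stub,
        py::arg("tensor_id"),
        py::kw_only(),
        py::arg("keep_l1_aligned") = false);
}
```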
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp
index e964aeb0efa..a1415761d2c 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.cpp
@@ -51,7 +51,7 @@ operation::ProgramWithCallbacks ShardedToInterleavedDeviceOperation::create_prog
     const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
     auto& output_tensor = output_tensors.at(0);
-    return detail::sharded_to_interleaved_multi_core(input_tensor, output_tensor);
+    return detail::sharded_to_interleaved_multi_core(input_tensor, output_tensor, this->is_l1_aligned);
 }
 
 }  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.hpp
index 060cb780671..82a11342442 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_op.hpp
@@ -13,6 +13,7 @@ namespace ttnn::operations::data_movement {
 struct ShardedToInterleavedDeviceOperation {
     const tt::tt_metal::MemoryConfig output_mem_config;
     const tt::tt_metal::DataType output_dtype;
+    const bool is_l1_aligned = false;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
     std::vector compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
index 8cba763bd54..aacb6af92cd 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
@@ -16,7 +16,7 @@ using namespace tt::tt_metal;
 namespace ttnn::operations::data_movement::detail {
 
 operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
-    const Tensor& input, const Tensor& output, uint32_t num_slices, uint32_t slice_index) {
+    const Tensor& input, const Tensor& output, bool is_l1_aligned, uint32_t num_slices, uint32_t slice_index) {
     tt_metal::Program program{};
 
     uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width,
@@ -235,8 +235,8 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
             uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM);
             uint32_t l1_alignment = hal.get_alignment(HalMemType::L1);
             uint32_t padded_shard_width = align(output_unit_size, dst_buffer->alignment());
-            if(is_blackhole) {
-                if(!dst_is_dram)
+            if(is_blackhole or is_l1_aligned) {
+                if(!dst_is_dram or is_l1_aligned)
                     padded_shard_width = align(output_unit_size, l1_alignment);
             }
             tt_metal::SetRuntimeArgs(
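For the read-back direction, the shard stride (`padded_shard_width`) is normally padded to the destination buffer's alignment; the change above extends the Blackhole L1 special case so that `is_l1_aligned` also forces the L1 stride even when the destination is DRAM. Below is a standalone sketch of that selection logic; the 64 B / 16 B alignment values are assumptions taken from the comments in this PR, not values queried from the HAL.

```cpp
// Standalone sketch of the padded_shard_width selection in the hunk above.
#include <cstdint>
#include <iostream>

constexpr uint32_t kDramAlignment = 64;  // assumed DRAM alignment
constexpr uint32_t kL1Alignment = 16;    // assumed L1 alignment

constexpr uint32_t align_up(uint32_t v, uint32_t a) { return ((v + a - 1) / a) * a; }

uint32_t padded_shard_width(uint32_t output_unit_size, bool dst_is_dram, bool is_blackhole, bool is_l1_aligned) {
    // Default: pad to the destination buffer's own alignment.
    uint32_t width = align_up(output_unit_size, dst_is_dram ? kDramAlignment : kL1Alignment);
    // Blackhole L1 outputs, and now any caller passing is_l1_aligned, use the L1 stride instead.
    if (is_blackhole || is_l1_aligned) {
        if (!dst_is_dram || is_l1_aligned) {
            width = align_up(output_unit_size, kL1Alignment);
        }
    }
    return width;
}

int main() {
    // 40-byte shard rows written to a DRAM-interleaved output.
    std::cout << padded_shard_width(40, true, false, false) << "\n";  // 64: padded to DRAM alignment
    std::cout << padded_shard_width(40, true, false, true) << "\n";   // 48: is_l1_aligned keeps the L1 stride
    return 0;
}
```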
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.hpp
index 4a432390c37..29b6a89d499 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.hpp
@@ -9,6 +9,9 @@ namespace ttnn::operations::data_movement::detail {
 
 operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
-    const Tensor& a, const Tensor& output, uint32_t num_slices = 1, uint32_t slice_index = 0);
-
+    const Tensor& a,
+    const Tensor& output,
+    bool is_l1_aligned = false,
+    uint32_t num_slices = 1,
+    uint32_t slice_index = 0);
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.cpp
index 723ee943894..17b54533d7e 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.cpp
@@ -15,14 +15,17 @@ ttnn::Tensor ShardedToInterleavedOperation::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
     const MemoryConfig& memory_config,
-    const std::optional<DataType>& output_dtype) {
+    const std::optional<DataType>& output_dtype,
+    const std::optional<bool>& is_l1_aligned) {
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     auto shard_spec = input_tensor.shard_spec().value();
     TT_FATAL(input_tensor.shard_spec().has_value(), "Error");
     return operation::run(
                ShardedToInterleavedDeviceOperation{
-                   .output_mem_config = memory_config, .output_dtype = output_dtype.value_or(input_tensor.get_dtype())},
+                   .output_mem_config = memory_config,
+                   .output_dtype = output_dtype.value_or(input_tensor.get_dtype()),
+                   .is_l1_aligned = is_l1_aligned.value_or(false)},
                {input_tensor})
         .at(0);
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp
index 1f8b72ec9f3..b06e2d3bf6e 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp
@@ -14,7 +14,8 @@ struct ShardedToInterleavedOperation {
         uint8_t queue_id,
         const ttnn::Tensor& input_tensor,
         const MemoryConfig& memory_config,
-        const std::optional<DataType>& output_dtype);
+        const std::optional<DataType>& output_dtype,
+        const std::optional<bool>& is_l1_aligned = std::nullopt);
 };
 }  // namespace operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.cpp
index 89fb27e7db7..ee08a332a75 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.cpp
@@ -27,18 +27,21 @@ void bind_sharded_to_interleaved(
                const ttnn::Tensor& input_tensor,
                const std::optional<MemoryConfig>& memory_config,
                const std::optional<DataType>& output_dtype,
-               uint8_t queue_id) -> ttnn::Tensor {
+               uint8_t queue_id,
+               const std::optional<bool>& is_l1_aligned) -> ttnn::Tensor {
                 return self(
                     queue_id,
                     input_tensor,
                     memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG),
-                    output_dtype);
+                    output_dtype,
+                    is_l1_aligned);
             },
             py::arg("input_tensor").noconvert(),
             py::arg("memory_config") = std::nullopt,
             py::arg("output_dtype") = std::nullopt,
             py::kw_only(),
             py::arg("queue_id") = 0,
+            py::arg("is_l1_aligned") = false,
         });
 }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/device/interleaved_to_sharded_partial_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/device/interleaved_to_sharded_partial_op.cpp
index 5f84ae580ce..e7fb61543df 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/device/interleaved_to_sharded_partial_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/device/interleaved_to_sharded_partial_op.cpp
@@ -74,7 +74,8 @@ operation::ProgramWithCallbacks InterleavedToShardedPartialDeviceOperation::crea
     const auto& input_tensor = input_tensors.at(0);
     auto& output_tensor = output_tensors.at(0);
     // Will move with sharded ops
-    return detail::interleaved_to_sharded_multi_core(input_tensor, output_tensor, this->num_slices, this->slice_index);
+    return detail::interleaved_to_sharded_multi_core(
+        input_tensor, output_tensor, false, this->num_slices, this->slice_index);
 }
 
 }  // namespace ttnn::operations::data_movement
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/device/sharded_to_interleaved_partial_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/device/sharded_to_interleaved_partial_op.cpp
index 40df37dfd50..de6714ba07b 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/device/sharded_to_interleaved_partial_op.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/device/sharded_to_interleaved_partial_op.cpp
@@ -60,7 +60,8 @@ operation::ProgramWithCallbacks ShardedToInterleavedPartialDeviceOperation::crea
     const auto& input_tensor = input_tensors.at(0);
     auto& output_tensor = input_tensors[1];
     // Will move with sharded ops
-    return detail::sharded_to_interleaved_multi_core(input_tensor, output_tensor, this->num_slices, this->slice_index);
+    return detail::sharded_to_interleaved_multi_core(
+        input_tensor, output_tensor, false, this->num_slices, this->slice_index);
 }
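One reason the two partial-op call sites above must be touched at all: the new `bool` parameter is inserted before the already-defaulted `num_slices` / `slice_index` parameters, so an unmodified positional call would still compile but would silently convert `num_slices` into the flag. The toy reproduction below (a hypothetical stub, not the real ttnn function) shows the hazard and the corrected call.

```cpp
// Why the partial-op call sites are updated to pass `false` explicitly.
#include <cstdint>
#include <iostream>

// New-style signature, mirroring interleaved_to_sharded_multi_core after this change.
void multi_core_stub(bool keep_l1_aligned = false, uint32_t num_slices = 1, uint32_t slice_index = 0) {
    std::cout << "keep_l1_aligned=" << keep_l1_aligned << " num_slices=" << num_slices
              << " slice_index=" << slice_index << "\n";
}

int main() {
    uint32_t num_slices = 4, slice_index = 2;
    // Old-style call: 4 silently becomes keep_l1_aligned=true and slice_index lands in num_slices.
    multi_core_stub(num_slices, slice_index);
    // Updated call, as in the partial ops above: the flag stays false and slices keep their meaning.
    multi_core_stub(false, num_slices, slice_index);
    return 0;
}
```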