diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index c5960b650450..52b342f925a9 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -187,7 +187,12 @@ def run_downsample_if_req(
                 reshard_if_not_optimal=reshard_if_not_optimal,
                 transpose_shards=transpose_shards,
                 packer_l1_accum_enabled=packer_l1_accum_enabled,
-                enable_act_double_buffer=enable_act_double_buffer,
+                enable_act_double_buffer=enable_act_double_buffer
+                if height_sharding
+                else True
+                if input_width < 56
+                else False,
+                enable_weights_double_buffer=True if input_width < 56 else False,
                 enable_split_reader=enable_split_reader,
                 enable_subblock_padding=enable_subblock_padding,
             ),
@@ -335,6 +340,7 @@ def __call__(
                 transpose_shards=transpose_shards,
                 packer_l1_accum_enabled=packer_l1_acc,
                 enable_act_double_buffer=enable_act_double_buffer,
+                enable_weights_double_buffer=True,
                 enable_split_reader=enable_split_reader,
                 enable_subblock_padding=enable_subblock_padding,
             ),
@@ -846,6 +852,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
 
         reshard = False
         height_shard = False
+        is_gs = is_grayskull()
         if is_wormhole_b0() and self.batch_size == 20:
             if is_first_run:
                 reshard = True if not is_wormhole_b0() else False
@@ -864,7 +871,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             reshard_if_not_optimal=reshard,
             height_sharding=height_shard,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -888,7 +895,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=False,
+            enable_act_double_buffer=True,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -902,7 +909,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -916,7 +923,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -942,7 +949,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             reshard_if_not_optimal=reshard,
             height_sharding=height_shard,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -966,7 +973,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -980,7 +987,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -994,7 +1001,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -1008,7 +1015,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -1023,7 +1030,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             conv_op_cache,
             eltwise_binary_out_in_place=True,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -1065,7 +1072,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             reshard_if_not_optimal=reshard,
             height_sharding=height_shard,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -1089,7 +1096,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -1103,7 +1110,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py
index 2e0d838e5baf..7e9a115582d4 100644
--- a/models/demos/vgg/tt/ttnn_vgg.py
+++ b/models/demos/vgg/tt/ttnn_vgg.py
@@ -104,6 +104,7 @@ def ttnn_vgg16(
             ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
         ),
         reshard_if_not_optimal=True,
+        enable_weights_double_buffer=True,
     )
 
     tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight
@@ -226,6 +227,7 @@ def ttnn_vgg11(
         shard_layout=(
             ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
         ),
+        enable_weights_double_buffer=True,
     )
 
     tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight
diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
index 52a417526b1b..0f3176775cd0 100644
--- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
+++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -457,7 +457,6 @@ def test_conv2d_localrun(device, input_spec):
     [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1],  # 6
     [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1],  # 14
     [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1],  # 15
-    [1, 192, 192, 99, 99, 5, 5, 2, 2, 0, 0, 192, False, 1],  # 100
     [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1],  # 141
     [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1],  # 170
     [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1],  # 171
@@ -472,8 +471,6 @@ def test_conv2d_localrun(device, input_spec):
     [1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1],  # 292
     [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1],  # 293
     [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1],  # 294
-    [1, 576, 576, 19, 19, 5, 5, 1, 1, 2, 2, 576, False, 1],  # 300
-    [1, 672, 672, 24, 24, 5, 5, 1, 1, 2, 2, 672, False, 1],  # 341
     [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1],  # 347
     [1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1],  # 348
     [1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1],  # 363
diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
index 47f4af61aee4..0e08ee95e027 100644
--- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py
+++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -389,9 +389,6 @@ def test_conv_features(
     fp32_accum,
     packer_l1_acc,
 ):
-    if shard_layout == ttnn.TensorMemoryLayout.BLOCK_SHARDED and filter > 3:
-        pytest.skip("Block sharding only supports filter size <= 3")
-
     if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
         pytest.skip("Row major layout not compatible with bfloat8_b")
 
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
index fc408a8c76b8..4ea15c5af084 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -200,6 +200,7 @@ std::tuple
     auto py_conv_config = py::class_<Conv2dConfig>(module, "Conv2dConfig");
     py_conv_config.def(
-        py::init, std::optional, bool, Layout, bool, bool, bool>(),
+        py::init, std::optional, bool, Layout, bool, bool, bool, bool>(),
         py::kw_only(),
         py::arg("math_fidelity") = MathFidelity::HiFi4,
         py::arg("dtype") = DataType::BFLOAT16,
@@ -340,6 +340,7 @@ void py_bind_conv2d(py::module& module) {
         py::arg("transpose_shards") = true,
         py::arg("output_layout") = Layout::TILE,
         py::arg("enable_act_double_buffer") = false,
+        py::arg("enable_weights_double_buffer") = false,
         py::arg("enable_split_reader") = false,
         py::arg("enable_subblock_padding") = false
     );
@@ -362,6 +363,7 @@ void py_bind_conv2d(py::module& module) {
     py_conv_config.def_readwrite("transpose_shards", &Conv2dConfig::transpose_shards);
     py_conv_config.def_readwrite("output_layout", &Conv2dConfig::output_layout);
     py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer);
+    py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer);
     py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader);
     py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding);
 
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
index 3c3198778c37..0a9bf24184c8 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -46,6 +46,7 @@ struct Conv2dConfig {
     bool transpose_shards = true;  // used only if override_sharding_config is true and if height sharding is false
     Layout output_layout = Layout::TILE;
     bool enable_act_double_buffer = false;
+    bool enable_weights_double_buffer = false;  // Used only for block sharded convolutions
     bool enable_split_reader = false;
     bool enable_subblock_padding = false;
     static constexpr auto attribute_names = std::make_tuple(
@@ -68,6 +69,7 @@ struct Conv2dConfig {
         "transpose_shards",
         "output_layout",
         "enable_act_double_buffer",
+        "enable_weights_double_buffer",
         "enable_split_reader",
         "enable_subblock_padding");
     const auto attribute_values() const {
@@ -91,6 +93,7 @@ struct Conv2dConfig {
             std::cref(this->transpose_shards),
             std::cref(this->output_layout),
             std::cref(this->enable_act_double_buffer),
+            std::cref(this->enable_weights_double_buffer),
             std::cref(this->enable_split_reader),
             std::cref(this->enable_subblock_padding));
     }
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
index f14101aa7cbf..1ae2e0de1d8b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
@@ -60,13 +60,14 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b,
    std::optional compute_kernel_config,
    bool enable_act_double_buffer,
+   bool enable_weights_double_buffer,
    bool enable_split_reader,
    bool enable_subblock_padding,
    bool use_non_tile_height
 ) {
     std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))};
     operation::launch_op(
-        [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height]
+        [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height]
         (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector {
             using ttnn::operations::experimental::auto_format::FormatParams;
             auto& a = input_tensors.at(0);
@@ -86,7 +87,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional
             bool fp32_accum = a.device()->arch() == tt::ARCH::WORMHOLE_B0;  // && compute_kernel_config.has_value()) ? compute_kernel_config.value().fp32_dest_acc_en : false;
             auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false);
             return operation::run_without_autoformat(
-                OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height
+                OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height
                 ),
                 input_tensors,
                 optional_input_tensors);
@@ -235,6 +236,7 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vect
         compute_kernel_config,
         output_tensor,
         enable_act_double_buffer,
+        enable_weights_double_buffer,
         enable_split_reader,
         enable_subblock_padding,
         use_non_tile_height);
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
index 038144993ab6..391465129deb 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
@@ -58,6 +58,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T
    std::optional compute_kernel_config,
    Tensor& output,
    bool enable_act_double_buffer,
+   bool enable_weights_double_buffer,
    bool enable_split_reader,
    bool enable_subblock_padding,
    bool use_non_tile_height);
@@ -77,6 +78,7 @@ struct OptimizedConvNew {
    bool use_shallow_conv_variant;
    const DeviceComputeKernelConfig compute_kernel_config;
    bool enable_act_double_buffer;
+   bool enable_weights_double_buffer;
    bool enable_split_reader;
    bool enable_subblock_padding;
    bool use_non_tile_height;
@@ -89,7 +91,7 @@ struct OptimizedConvNew {
        MemoryConfig out_mem_config,
        DataType dtype,
        std::array input_tensor_shape, bool use_shallow_conv_variant,
-       const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) :
+       const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) :
        output_channels(output_channels),
        groups(groups),
        sliding_window_config(sliding_window_config),
@@ -104,6 +106,7 @@ struct OptimizedConvNew {
        use_shallow_conv_variant(use_shallow_conv_variant),
        compute_kernel_config(compute_kernel_config),
        enable_act_double_buffer(enable_act_double_buffer),
+       enable_weights_double_buffer(enable_weights_double_buffer),
        enable_split_reader(enable_split_reader),
        enable_subblock_padding(enable_subblock_padding),
        use_non_tile_height(use_non_tile_height) {}
@@ -128,6 +131,7 @@ struct OptimizedConvNew {
        "input_tensor_shape",
        "use_shallow_conv_variant",
        "enable_act_double_buffer",
+       "enable_weights_double_buffer",
        "enable_split_reader",
        "enable_subblock_padding");
    const auto attribute_values() const {
@@ -144,6 +148,7 @@ struct OptimizedConvNew {
            std::cref(this->input_tensor_shape),
            std::cref(this->use_shallow_conv_variant),
            std::cref(this->enable_act_double_buffer),
+           std::cref(this->enable_weights_double_buffer),
            std::cref(this->enable_split_reader),
            std::cref(this->enable_subblock_padding));
    }
@@ -162,6 +167,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b,
    std::optional compute_kernel_config = std::nullopt,
    bool enable_act_double_buffer = false,
+   bool enable_weights_double_buffer = false,
    bool enable_split_reader = false,
    bool enable_subblock_padding = false,
    bool use_non_tile_height = false
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp
index 9573ff0595c9..514ed9cc35fb 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp
@@ -359,6 +359,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
    Tensor& output,
    DeviceComputeKernelConfig compute_kernel_config,
    bool enable_act_double_buffer,
+   bool enable_weights_double_buffer,
    bool enable_split_reader,
    bool enable_subblock_padding,
    bool use_non_tile_height) {
@@ -766,12 +767,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
    uint32_t window_outer;
    uint32_t window_inner;
 
-   if (weight_width_sliced and filter_w == 3) {
-       window_outer = 1;  // window_outer = 1 becasue all of filter window is processed in the inner loop
-       window_inner = 3;  // window_inner = 9 / 3, ie. read 3 width coalesced
+   if (weight_width_sliced) {
+       window_outer = 1;
+       window_inner = filter_h;
    } else {
-       window_outer = num_blocks_act_w;                        // window_outer
-       window_inner = filter_h * filter_w / num_blocks_act_w;  // window_inner
+       window_outer = num_blocks_act_w;
+       window_inner = filter_h * filter_w / num_blocks_act_w;
    }
 
    reader_defines["WINDOW_INNER"] = std::to_string(window_inner);
@@ -928,9 +929,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
    uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks;
    uint32_t num_act_cb_second_reader_tiles = 0;
    // TODO: This flag should be set in kernel logic but need this for create_CB
-   if (a.memory_config().is_sharded() and ((filter_h == 3 and filter_w == 3 and
-       (stride_h == 1 or stride_h == 2)) or (filter_h == 1 and filter_w == 1 and stride_h == 2)) and weight_width_sliced) {
-       // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire 3x3 window before
+   if (weight_width_sliced) {
+       // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire filter_h x filter_w window before
        // pushing in reader/writer
        // TODO: Generalize this to not make this assumption
        read_window_in_inner_loop = true;
@@ -943,8 +943,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
 
    if (fully_buffer_weights) {
        num_weight_cb_tiles *= window_outer;
-   } else if (per_core_out_matrix_width_ntiles < 5 && per_core_out_matrix_height_ntiles < 22) {  // Q: where are these
-       // numbers from?
+   } else if (enable_weights_double_buffer) {
        num_weight_cb_tiles = num_weight_cb_tiles * 2;
    }
 
@@ -961,9 +960,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
    } else {
        if (enable_act_double_buffer) {
            num_act_cb_tiles = num_act_cb_tiles * 2;
-       } else if (conv_act_size_c / conv_act_c_blocks < 160 &&
-                  per_core_out_matrix_height_ntiles < 22) {  // Q: where are these numbers from?
-           num_act_cb_tiles = num_act_cb_tiles * 2;  // double buffered
        }
    }
    uint32_t out_block_h_ntiles_padded = num_blocks_act_h_per_core * act_block_h_ntiles;
@@ -1685,6 +1681,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(
    std::optional compute_kernel_config,
    Tensor& output,
    bool enable_act_double_buffer,
+   bool enable_weights_double_buffer,
    bool enable_split_reader,
    bool enable_subblock_padding,
    bool use_non_tile_height) {
@@ -1758,6 +1755,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(
        output,
        compute_kernel_config.value(),
        enable_act_double_buffer,
+       enable_weights_double_buffer,
        enable_split_reader,
        enable_subblock_padding,
        use_non_tile_height);
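
Usage sketch (not part of the patch): after this change, enable_weights_double_buffer is an ordinary ttnn.Conv2dConfig field next to enable_act_double_buffer, and the sharded program factory doubles num_weight_cb_tiles when it is set instead of guessing from the per-core matrix dimensions. The snippet below mirrors the run_downsample_if_req change in the ResNet-50 model; the input_width < 56 threshold, dtype, and math fidelity are assumptions carried over from this diff, not requirements of the API.

# Illustrative only: mirrors the ResNet-50 downsample change in this patch.
# The width threshold (56) and the dtype/fidelity choices are assumptions
# lifted from the diff above, not recommendations.
import ttnn


def downsample_conv_config(enable_act_double_buffer, height_sharding, input_width):
    return ttnn.Conv2dConfig(
        math_fidelity=ttnn.MathFidelity.LoFi,
        dtype=ttnn.bfloat16,
        shard_layout=(
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
        ),
        # Activations: keep the caller's choice when height sharded, otherwise
        # double buffer only the narrow convolutions.
        enable_act_double_buffer=enable_act_double_buffer if height_sharding else input_width < 56,
        # Weights: the new flag; it only affects block sharded convolutions,
        # where the program factory doubles the weight CB allocation.
        enable_weights_double_buffer=input_width < 56,
        enable_split_reader=False,
        enable_subblock_padding=False,
    )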