From 616e1f6e9f17501c28ab187dc30094a8776ceaec Mon Sep 17 00:00:00 2001
From: Pavle Josipovic
Date: Tue, 29 Oct 2024 11:52:17 +0000
Subject: [PATCH] #13794: Conv2d BS arbitrary kernel dims

Allow arbitrary kernel dimensions in block-sharded conv2d. This enables more
torch traces to pass, since block sharding is now a more viable option in the
auto-shard codepath for convs.

The logic that arbitrarily double-buffered activations and weights in the
block-sharded code path is removed. It caused issues with some torch traces
that used to pass with height sharding: auto-shard now picks block sharding as
the better option, but the arbitrary double buffering would run out of memory.
Double buffering of weights is now opt-in via the new
enable_weights_double_buffer config flag.
---
 .../ttnn_functional_resnet50_new_conv_api.py  | 35 +++++++++++--------
 models/demos/vgg/tt/ttnn_vgg.py               |  2 ++
 .../sweeps/conv2d/short/conv2d_short_sweep.py |  3 --
 .../unit_tests/operations/test_new_conv2d.py  |  3 --
 .../ttnn/operations/conv/conv2d/conv2d.cpp    |  2 ++
 .../operations/conv/conv2d/conv2d_pybind.cpp  |  4 ++-
 .../operations/conv/conv2d/conv2d_utils.hpp   |  3 ++
 .../conv/conv2d/device/conv2d_op.cpp          |  6 ++--
 .../conv/conv2d/device/conv2d_op.hpp          |  8 ++++-
 .../conv2d_op_sharded_program_factory.cpp     | 24 ++++++------
 10 files changed, 53 insertions(+), 37 deletions(-)

diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index c5960b650450..52b342f925a9 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -187,7 +187,12 @@ def run_downsample_if_req(
                 reshard_if_not_optimal=reshard_if_not_optimal,
                 transpose_shards=transpose_shards,
                 packer_l1_accum_enabled=packer_l1_accum_enabled,
-                enable_act_double_buffer=enable_act_double_buffer,
+                enable_act_double_buffer=enable_act_double_buffer
+                if height_sharding
+                else True
+                if input_width < 56
+                else False,
+                enable_weights_double_buffer=True if input_width < 56 else False,
                 enable_split_reader=enable_split_reader,
                 enable_subblock_padding=enable_subblock_padding,
             ),
@@ -335,6 +340,7 @@ def __call__(
                 transpose_shards=transpose_shards,
                 packer_l1_accum_enabled=packer_l1_acc,
                 enable_act_double_buffer=enable_act_double_buffer,
+                enable_weights_double_buffer=True,
                 enable_split_reader=enable_split_reader,
                 enable_subblock_padding=enable_subblock_padding,
             ),
@@ -846,6 +852,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
 
         reshard = False
         height_shard = False
+        is_gs = is_grayskull()
         if is_wormhole_b0() and self.batch_size == 20:
             if is_first_run:
                 reshard = True if not is_wormhole_b0() else False
@@ -864,7 +871,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             reshard_if_not_optimal=reshard,
             height_sharding=height_shard,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=True if whb0_and_b16 else False,
+            enable_act_double_buffer=True if whb0_and_b16 or is_gs else False,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -888,7 +895,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-            enable_act_double_buffer=False,
+            enable_act_double_buffer=True,
             enable_split_reader=False,
             enable_subblock_padding=False,
         )
@@ -902,7 +909,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
             x_width,
             conv_op_cache,
             transpose_shards=self.transpose_shards,
-
enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -916,7 +923,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -942,7 +949,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt reshard_if_not_optimal=reshard, height_sharding=height_shard, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -966,7 +973,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -980,7 +987,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -994,7 +1001,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1008,7 +1015,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1023,7 +1030,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt conv_op_cache, eltwise_binary_out_in_place=True, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1065,7 +1072,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt reshard_if_not_optimal=reshard, height_sharding=height_shard, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1089,7 +1096,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1103,7 +1110,7 @@ def run(self, input_tensor, device, 
ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, + enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, enable_split_reader=False, enable_subblock_padding=False, ) diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py index 2e0d838e5baf..7e9a115582d4 100644 --- a/models/demos/vgg/tt/ttnn_vgg.py +++ b/models/demos/vgg/tt/ttnn_vgg.py @@ -104,6 +104,7 @@ def ttnn_vgg16( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED ), reshard_if_not_optimal=True, + enable_weights_double_buffer=True, ) tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight @@ -226,6 +227,7 @@ def ttnn_vgg11( shard_layout=( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED ), + enable_weights_double_buffer=True, ) tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py index 52a417526b1b..0f3176775cd0 100644 --- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py @@ -457,7 +457,6 @@ def test_conv2d_localrun(device, input_spec): [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 6 [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], # 14 [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 15 - [1, 192, 192, 99, 99, 5, 5, 2, 2, 0, 0, 192, False, 1], # 100 [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], # 141 [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], # 170 [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 171 @@ -472,8 +471,6 @@ def test_conv2d_localrun(device, input_spec): [1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1], # 292 [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 293 [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1], # 294 - [1, 576, 576, 19, 19, 5, 5, 1, 1, 2, 2, 576, False, 1], # 300 - [1, 672, 672, 24, 24, 5, 5, 1, 1, 2, 2, 672, False, 1], # 341 [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1], # 347 [1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1], # 348 [1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1], # 363 diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 47f4af61aee4..0e08ee95e027 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -389,9 +389,6 @@ def test_conv_features( fp32_accum, packer_l1_acc, ): - if shard_layout == ttnn.TensorMemoryLayout.BLOCK_SHARDED and filter > 3: - pytest.skip("Block sharding only supports filter size <= 3") - if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b: pytest.skip("Row major layout not compatible with bfloat8_b") diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index fc408a8c76b8..4ea15c5af084 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -200,6 +200,7 @@ std::tuple(module, "Conv2dConfig"); py_conv_config.def( - py::init, std::optional, bool, Layout, bool, bool, bool>(), + py::init, std::optional, bool, Layout, bool, bool, bool, bool>(), py::kw_only(), py::arg("math_fidelity") = 
MathFidelity::HiFi4, py::arg("dtype") = DataType::BFLOAT16, @@ -340,6 +340,7 @@ void py_bind_conv2d(py::module& module) { py::arg("transpose_shards") = true, py::arg("output_layout") = Layout::TILE, py::arg("enable_act_double_buffer") = false, + py::arg("enable_weights_double_buffer") = false, py::arg("enable_split_reader") = false, py::arg("enable_subblock_padding") = false ); @@ -362,6 +363,7 @@ void py_bind_conv2d(py::module& module) { py_conv_config.def_readwrite("transpose_shards", &Conv2dConfig::transpose_shards); py_conv_config.def_readwrite("output_layout", &Conv2dConfig::output_layout); py_conv_config.def_readwrite("enable_act_double_buffer", &Conv2dConfig::enable_act_double_buffer); + py_conv_config.def_readwrite("enable_weights_double_buffer", &Conv2dConfig::enable_weights_double_buffer); py_conv_config.def_readwrite("enable_split_reader", &Conv2dConfig::enable_split_reader); py_conv_config.def_readwrite("enable_subblock_padding", &Conv2dConfig::enable_subblock_padding); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index 3c3198778c37..0a9bf24184c8 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -46,6 +46,7 @@ struct Conv2dConfig { bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false Layout output_layout = Layout::TILE; bool enable_act_double_buffer = false; + bool enable_weights_double_buffer = false; // Used on for block sharded convolutions bool enable_split_reader = false; bool enable_subblock_padding = false; static constexpr auto attribute_names = std::make_tuple( @@ -68,6 +69,7 @@ struct Conv2dConfig { "transpose_shards", "output_layout", "enable_act_double_buffer", + "enable_weights_double_buffer", "enable_split_reader", "enable_subblock_padding"); const auto attribute_values() const { @@ -91,6 +93,7 @@ struct Conv2dConfig { std::cref(this->transpose_shards), std::cref(this->output_layout), std::cref(this->enable_act_double_buffer), + std::cref(this->enable_weights_double_buffer), std::cref(this->enable_split_reader), std::cref(this->enable_subblock_padding)); } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index f14101aa7cbf..1ae2e0de1d8b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -60,13 +60,14 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional compute_kernel_config, bool enable_act_double_buffer, + bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height ) { std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))}; operation::launch_op( - [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height] + [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, 
use_non_tile_height] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { using ttnn::operations::experimental::auto_format::FormatParams; auto& a = input_tensors.at(0); @@ -86,7 +87,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optionalarch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? compute_kernel_config.value().fp32_dest_acc_en : false; auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false); return operation::run_without_autoformat( - OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height + OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height ), input_tensors, optional_input_tensors); @@ -235,6 +236,7 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vect compute_kernel_config, output_tensor, enable_act_double_buffer, + enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp index 038144993ab6..391465129deb 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp @@ -58,6 +58,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T std::optional compute_kernel_config, Tensor& output, bool enable_act_double_buffer, + bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height); @@ -77,6 +78,7 @@ struct OptimizedConvNew { bool use_shallow_conv_variant; const DeviceComputeKernelConfig compute_kernel_config; bool enable_act_double_buffer; + bool enable_weights_double_buffer; bool enable_split_reader; bool enable_subblock_padding; bool use_non_tile_height; @@ -89,7 +91,7 @@ struct OptimizedConvNew { MemoryConfig out_mem_config, DataType dtype, std::array input_tensor_shape, bool use_shallow_conv_variant, - const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) : + const DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) : output_channels(output_channels), groups(groups), sliding_window_config(sliding_window_config), @@ -104,6 +106,7 @@ struct OptimizedConvNew { use_shallow_conv_variant(use_shallow_conv_variant), compute_kernel_config(compute_kernel_config), enable_act_double_buffer(enable_act_double_buffer), + enable_weights_double_buffer(enable_weights_double_buffer), enable_split_reader(enable_split_reader), enable_subblock_padding(enable_subblock_padding), 
use_non_tile_height(use_non_tile_height) {} @@ -128,6 +131,7 @@ struct OptimizedConvNew { "input_tensor_shape", "use_shallow_conv_variant", "enable_act_double_buffer", + "enable_weights_double_buffer", "enable_split_reader", "enable_subblock_padding"); const auto attribute_values() const { @@ -144,6 +148,7 @@ struct OptimizedConvNew { std::cref(this->input_tensor_shape), std::cref(this->use_shallow_conv_variant), std::cref(this->enable_act_double_buffer), + std::cref(this->enable_weights_double_buffer), std::cref(this->enable_split_reader), std::cref(this->enable_subblock_padding)); } @@ -162,6 +167,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional compute_kernel_config = std::nullopt, bool enable_act_double_buffer = false, + bool enable_weights_double_buffer = false, bool enable_split_reader = false, bool enable_subblock_padding = false, bool use_non_tile_height = false diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index 9573ff0595c9..514ed9cc35fb 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -359,6 +359,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( Tensor& output, DeviceComputeKernelConfig compute_kernel_config, bool enable_act_double_buffer, + bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) { @@ -766,12 +767,12 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t window_outer; uint32_t window_inner; - if (weight_width_sliced and filter_w == 3) { - window_outer = 1; // window_outer = 1 becasue all of filter window is processed in the inner loop - window_inner = 3; // window_inner = 9 / 3, ie. read 3 width coalesced + if (weight_width_sliced) { + window_outer = 1; + window_inner = filter_h; } else { - window_outer = num_blocks_act_w; // window_outer - window_inner = filter_h * filter_w / num_blocks_act_w; // window_inner + window_outer = num_blocks_act_w; + window_inner = filter_h * filter_w / num_blocks_act_w; } reader_defines["WINDOW_INNER"] = std::to_string(window_inner); @@ -928,9 +929,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks; uint32_t num_act_cb_second_reader_tiles = 0; // TODO: This flag should be set in kernel logic but need this for create_CB - if (a.memory_config().is_sharded() and ((filter_h == 3 and filter_w == 3 and - (stride_h == 1 or stride_h == 2)) or (filter_h == 1 and filter_w == 1 and stride_h == 2)) and weight_width_sliced) { - // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire 3x3 window before + if (weight_width_sliced) { + // If conv_act_c_blocks > 1 and we have 2D conv with sharded input, we always read entire filter_h x filter_w window before // pushing in reader/writer // TODO: Generalize this to not make this assumption read_window_in_inner_loop = true; @@ -943,8 +943,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( if (fully_buffer_weights) { num_weight_cb_tiles *= window_outer; - } else if (per_core_out_matrix_width_ntiles < 5 && per_core_out_matrix_height_ntiles < 22) { // Q: where are these - // numbers from? 
+ } else if (enable_weights_double_buffer) { num_weight_cb_tiles = num_weight_cb_tiles * 2; } @@ -961,9 +960,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( } else { if (enable_act_double_buffer) { num_act_cb_tiles = num_act_cb_tiles * 2; - } else if (conv_act_size_c / conv_act_c_blocks < 160 && - per_core_out_matrix_height_ntiles < 22) { // Q: where are these numbers from? - num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered } } uint32_t out_block_h_ntiles_padded = num_blocks_act_h_per_core * act_block_h_ntiles; @@ -1685,6 +1681,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( std::optional compute_kernel_config, Tensor& output, bool enable_act_double_buffer, + bool enable_weights_double_buffer, bool enable_split_reader, bool enable_subblock_padding, bool use_non_tile_height) { @@ -1758,6 +1755,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( output, compute_kernel_config.value(), enable_act_double_buffer, + enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height);
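
For reference, a minimal usage sketch of the new flag, assuming the config is
passed to ttnn.conv2d via conv_config as in the VGG and ResNet demos above;
the other field values shown here are illustrative, not prescribed by this
patch:

    import ttnn

    conv_config = ttnn.Conv2dConfig(
        dtype=ttnn.bfloat16,
        shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
        reshard_if_not_optimal=True,
        # Activation double buffering is now purely opt-in; the implicit
        # size-based heuristic in the block-sharded program factory is removed.
        enable_act_double_buffer=True,
        # New flag: double-buffer the weights CB (block-sharded convs only).
        enable_weights_double_buffer=True,
    )

Both flags default to False, so existing callers keep single-buffered circular
buffers unless they opt in, trading extra L1 usage for overlap between data
movement and compute.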