diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py
index 39f7adeae11..43fa209acb0 100644
--- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py
+++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -8,7 +8,7 @@
 import pytest
 import math
 
-from models.utility_functions import is_wormhole_b0
+from models.utility_functions import is_wormhole_b0, is_grayskull, is_x2_harvested
 from tests.ttnn.utils_for_testing import assert_with_pcc
 
 import ttnn
@@ -34,8 +34,15 @@ def run_max_pool(
     if shard_scheme != ttnn.TensorMemoryLayout.WIDTH_SHARDED:
         if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
             pytest.skip("Invalid case")
-        if (kernel_h == 3 and pad_h != 1) or (kernel_h == 2 and pad_h != 0):
-            pytest.skip("kernel size and padding combination not supported")
+
+        if (
+            (kernel_h == 13 and pad_h != 6)
+            or (kernel_h == 9 and pad_h != 4)
+            or (kernel_h == 5 and pad_h != 2)
+            or (kernel_h == 3 and pad_h != 1)
+            or (kernel_h == 2 and pad_h != 0)
+        ):
+            pytest.skip("kernel size and padding combination not supported")
 
     out_h = math.floor((in_h + 2 * pad_h - (dilation_h * kernel_h - 1) - 1) / stride_h) + 1
     out_w = math.floor((in_w + 2 * pad_w - (dilation_w * kernel_w - 1) - 1) / stride_w) + 1
@@ -47,15 +54,57 @@ def run_max_pool(
     if in_c % 16 != 0:
         pytest.skip("Current maxpool writer needs nchannels to be multiple of 16!")
     if in_c == 16 and dtype == ttnn.bfloat8_b and in_n * in_h * in_w > 600000:
-        pytest.skip("This case runs out of memory on Grayskull")
+        pytest.skip("This case runs out of memory")
     if in_n > 16 and in_c > 64 and dtype == ttnn.bfloat8_b and is_wormhole_b0():
         pytest.skip("This case runs out of memory on Wormhole b0")
+    if (
+        stride == (1, 1)
+        and (act_shape == [16, 64, 112, 112] or act_shape == [4, 16, 1056, 160] or act_shape == [16, 16, 528, 80])
+        and is_wormhole_b0()
+    ):
+        pytest.skip("This case runs out of memory on Wormhole b0")
+    if stride == (1, 1) and act_shape == [8, 16, 528, 80] and is_grayskull():
+        pytest.skip("This case runs out of memory on Grayskull")
+    if kernel_h > 3 and kernel_w > 3 and act_shape == [16, 64, 112, 112] and is_grayskull():
+        pytest.skip("This case runs out of memory on Grayskull")
+    if kernel_size == (13, 13) and act_shape == [128, 32, 132, 20] and is_grayskull():
+        pytest.skip("This case runs out of memory on Grayskull")
+    if kernel_h > 5 and kernel_w > 5 and act_shape == [16, 64, 112, 112] and is_x2_harvested(device):
+        pytest.skip("This case runs out of memory on Wormhole X2")
+    if stride == (1, 1) and act_shape == [128, 32, 132, 20] and is_x2_harvested(device):
+        pytest.skip("This case runs out of memory on Wormhole X2")
+    if stride == (1, 1) and kernel_size == (13, 13) and act_shape == [32, 32, 264, 40] and is_x2_harvested(device):
+        pytest.skip("This case runs out of memory on Wormhole X2")
+    if (
+        dtype == ttnn.bfloat8_b
+        and (act_shape == [4, 16, 1056, 160] or act_shape == [16, 16, 528, 80])
+        and is_x2_harvested(device)
+    ):
+        pytest.skip("This case runs out of memory on Wormhole X2")
 
     if shard_scheme == ttnn.TensorMemoryLayout.WIDTH_SHARDED:
         if in_c < max_cores:
             pytest.skip("Width sharding requires channles >= cores")
         if in_c / max_cores < 16:
             pytest.skip("Width sharding requires large enough channels to shard (at least 16 per core)")
+        if (
+            kernel_size == (13, 13)
+            and (act_shape == [8, 4096, 10, 16] or act_shape == [1, 32768, 10, 10])
+            and is_grayskull()
+        ):
+            pytest.skip("This case runs out of memory on Grayskull")
+        if (
+            stride == (1, 1)
+            and kernel_h > 5
+            and kernel_w > 5
+            and (act_shape == [4, 1024, 40, 40] or act_shape == [2, 2048, 40, 40] or act_shape == [8, 4096, 10, 16])
+            and is_x2_harvested(device)
+        ):
+            pytest.skip("This case runs out of memory on Wormhole X2")
+        if kernel_h > 5 and kernel_w > 5 and act_shape == [8, 4096, 10, 16] and is_x2_harvested(device):
+            pytest.skip("This case runs out of memory on Wormhole X2")
+        if kernel_size == (13, 13) and act_shape == [1, 32768, 10, 10] and is_x2_harvested(device):
+            pytest.skip("This case runs out of memory on Wormhole X2")
 
     if shard_scheme == ttnn.TensorMemoryLayout.BLOCK_SHARDED:
         if in_c < cores_x:
@@ -150,7 +199,12 @@ def run_max_pool(
     output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
     output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2))  ## N, C, H, W
-    passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch)
+
+    pcc_thresh = 1.0
+    if dtype == ttnn.bfloat8_b:
+        pcc_thresh = 0.9997
+
+    passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch, pcc_thresh)
 
     logger.debug(f"Passing: {passing}, PCC: {pcc}")
 
@@ -213,6 +267,8 @@ def run_max_pool(
             [1, 256, 56, 56],
             [1, 512, 28, 28],
             [1, 512, 14, 14],
+            # wide yolo kernel
+            [1, 512, 10, 10],
         )
     ),
 )
@@ -221,6 +277,9 @@ def run_max_pool(
     (
         (2, 2),
         (3, 3),
+        (5, 5),
+        (9, 9),
+        (13, 13),
     ),
 )
 @pytest.mark.parametrize(
@@ -228,14 +287,26 @@ def run_max_pool(
     (
         (0, 0),
         (1, 1),
+        (2, 2),
+        (4, 4),
+        (6, 6),
     ),
 )
 @pytest.mark.parametrize(
     "stride",
-    ((2, 2),),
+    (
+        (1, 1),
+        (2, 2),
+    ),
 )
 @pytest.mark.parametrize("dilation", ((1, 1),))  ## default
-@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
 def test_run_max_pool(
     act_shape,
     kernel_size,
@@ -249,26 +320,6 @@ def test_run_max_pool(
     run_max_pool(act_shape, kernel_size, padding, stride, dilation, device, dtype)
 
 
-@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
-@pytest.mark.parametrize(
-    "act_shape",  ## NCHW
-    (
-        (
-            [8, 64, 112, 112],
-            [1, 512, 10, 10],
-        )
-    ),
-)
-@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
-def test_run_max_pool_mem_config(
-    act_shape,
-    device,
-    memory_config,
-    use_program_cache,
-):
-    run_max_pool(act_shape, (3, 3), (1, 1), (2, 2), (1, 1), device, ttnn.bfloat16, memory_config=memory_config)
-
-
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
 @pytest.mark.parametrize(
     "act_shape",  ## NCHW
@@ -282,6 +333,8 @@ def test_run_max_pool_mem_config(
             [4, 1024, 40, 40],
             [2, 2048, 40, 40],
             [8, 4096, 10, 16],
+            # wide yolo kernel
+            [1, 32768, 10, 10],
         )
     ),
 )
@@ -290,6 +343,9 @@ def test_run_max_pool_mem_config(
     (
         (2, 2),
         (3, 3),
+        (5, 5),
+        (9, 9),
+        (13, 13),
     ),
 )
 @pytest.mark.parametrize(
@@ -297,14 +353,26 @@ def test_run_max_pool_mem_config(
     (
         (0, 0),
        (1, 1),
+        (2, 2),
+        (4, 4),
+        (6, 6),
     ),
 )
 @pytest.mark.parametrize(
     "stride",
-    ((2, 2),),
+    (
+        (1, 1),
+        (2, 2),
+    ),
 )
 @pytest.mark.parametrize("dilation", ((1, 1),))  ## default
-@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
 def test_run_max_pool_width_shard(
     act_shape,
     kernel_size,
@@ -359,6 +427,8 @@ def test_run_max_pool_width_shard(
             [4, 16, 1056, 160],
             [8, 16, 528, 80],
             [16, 16, 528, 80],
+            # wide yolo kernel
+            [1, 4096, 10, 10],
         )
     ),
 )
@@ -367,6 +437,9 @@ def test_run_max_pool_width_shard(
     (
         (2, 2),
         (3, 3),
+        (5, 5),
+        (9, 9),
+        (13, 13),
     ),
 )
 @pytest.mark.parametrize(
@@ -374,14 +447,26 @@ def test_run_max_pool_width_shard(
     (
         (0, 0),
         (1, 1),
+        (2, 2),
+        (4, 4),
+        (6, 6),
     ),
 )
 @pytest.mark.parametrize(
     "stride",
-    ((2, 2),),
+    (
+        (1, 1),
+        (2, 2),
+    ),
 )
 @pytest.mark.parametrize("dilation", ((1, 1),))  ## default
-@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
 def test_run_max_pool_block_shard(
     act_shape,
     kernel_size,
@@ -404,6 +489,26 @@ def test_run_max_pool_block_shard(
     )
 
 
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+@pytest.mark.parametrize(
+    "act_shape",  ## NCHW
+    (
+        (
+            [8, 64, 112, 112],
+            [1, 512, 10, 10],
+        )
+    ),
+)
+@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
+def test_run_max_pool_mem_config(
+    act_shape,
+    device,
+    memory_config,
+    use_program_cache,
+):
+    run_max_pool(act_shape, (3, 3), (1, 1), (2, 2), (1, 1), device, ttnn.bfloat16, memory_config=memory_config)
+
+
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
 @pytest.mark.parametrize(
     "act_shape",  ## NCHW
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp
index a0607bd0a4a..4296c4c042b 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core.cpp
@@ -43,24 +43,22 @@
 }
 #endif
 
-template 
+template 
 inline void reduce_h_fused(
     const uint32_t in_cb_id,
     const uint32_t in_scalar_cb_id,
-    const uint32_t in_ntiles_hwc_block,
     const uint32_t in_stick_index,
     const uint32_t out_cb_id) {
-    constexpr uint32_t num_output_tiles = out_ntiles_c / in_nblocks_c;
     constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
     constexpr uint32_t num_out_rows = 1;
-    for (uint32_t c_i = 0; c_i < in_nblocks_c; ++ c_i) {
+    for (uint32_t b_i = 0; b_i < in_nblocks_c; ++ b_i) {
         cb_reserve_back(out_cb_id, 1);
         const uint32_t curr_in_cb_id = split_reader ? (in_cb_id + (in_stick_index & 0x1)) : in_cb_id;
         cb_wait_front(curr_in_cb_id, 1);
         tile_regs_acquire();
-        unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
-        for (uint32_t c_i = 0; c_i < in_ntiles_c / in_nblocks_c; ++c_i) {
+        unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, num_output_tiles, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
+        for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
             reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
         }
         cb_pop_front(curr_in_cb_id, 1);
@@ -79,11 +77,9 @@ void MAIN {
 
     // NOTE: here it is assumed that in_ntiles_hw == 1. General cases not handled yet.
     constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
     constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
-    constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
     constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
     constexpr uint32_t out_h = get_compile_time_arg_val(4);
     constexpr uint32_t out_w = get_compile_time_arg_val(5);
-    constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);
 
     constexpr uint32_t split_reader = get_compile_time_arg_val(12);
@@ -91,8 +87,6 @@ void MAIN {
     constexpr uint32_t in_c = get_compile_time_arg_val(14);
     constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(15);
 
-    constexpr uint32_t num_output_tiles = out_ntiles_c;
-
     constexpr uint32_t in_cb_id = tt::CB::c_in0;    // and tt::CB::c_in1 for split reader
     constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
     constexpr uint32_t in_tiled_cb_id = tt::CB::c_intermed0;
@@ -103,13 +97,19 @@ void MAIN {
     constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
     constexpr uint32_t num_out_rows = 1;
 
-    constexpr uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;
-    tilizeA_B_reduce_init(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, out_cb_id, num_faces_in_tile, window_size_hw);
-    pack_untilize_dst_init_short(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
+    constexpr uint32_t num_output_tiles = in_ntiles_c / in_nblocks_c;
+    tilizeA_B_reduce_init(
+        in_cb_id,
+        in_scalar_cb_id,
+        num_output_tiles,
+        out_cb_id,
+        num_faces_in_tile,
+        window_size_hw);
+    pack_untilize_dst_init_short(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
 
     cb_wait_front(in_scalar_cb_id, 1);
     for (uint32_t i = 0; i < nsticks_per_core; ++ i) {
-        reduce_h_fused(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, i, out_cb_id);
+        reduce_h_fused(in_cb_id, in_scalar_cb_id, i, out_cb_id);
     }
     cb_pop_front(in_scalar_cb_id, 1);
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp
index 90a7d6c0a40..346a8ef8652 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp
@@ -53,59 +53,48 @@ inline void print_full_tile(uint32_t cb_id, uint32_t tile_id = 0, bool untilize
 #endif
 
 template <
-    uint32_t in_ntiles_hw,
-    uint32_t in_ntiles_c,
-    uint32_t out_ntiles_c,
-    uint32_t nblocks,
+    uint32_t num_output_tiles,
     bool is_partial_tile,
     uint32_t split_reader>
 inline void reduce_h_fused(
     const uint32_t in_cb_id,
     const uint32_t in_scalar_cb_id,
-    const uint32_t num_tiles_for_reduction,
     const uint32_t in_stick_index,
-    const uint32_t out_cb_id,
     const uint32_t unpA_face_r_dim) {
-    constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks;
     uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : unpA_face_r_dim < 32 ? 2 : 4;
     constexpr uint32_t num_out_rows = 1;
-    for (uint32_t out_elem_i = 0; out_elem_i < nblocks; ++out_elem_i) {
-        const uint32_t curr_in_cb_id =
-            split_reader ? (in_cb_id + (in_stick_index * nblocks + out_elem_i) & 0x1) : in_cb_id;
-        cb_wait_front(curr_in_cb_id, 1);
-        unpack_tilizeA_B_block(
-            curr_in_cb_id,
-            in_scalar_cb_id,
-            num_tiles_for_reduction,
-            0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
-            num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
-            unpA_face_r_dim);
-        for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) {
-            reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
-        }
-        cb_pop_front(curr_in_cb_id, 1);
+    const uint32_t curr_in_cb_id =
+        split_reader ? (in_cb_id + (in_stick_index & 0x1)) : in_cb_id;
+    cb_wait_front(curr_in_cb_id, 1);
+    unpack_tilizeA_B_block(
+        curr_in_cb_id,
+        in_scalar_cb_id,
+        num_output_tiles,
+        0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
+        num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
+        unpA_face_r_dim);
+    for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
+        reduce_tile_math(c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
     }
+    cb_pop_front(curr_in_cb_id, 1);
 }
 
 namespace NAMESPACE {
 void MAIN {
     // NOTE: here it is assumed that in_ntiles_hw == 1. General cases not handled yet.
-    constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
+    constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0); // note ntiles_hw will always be 1 in this kernel, when ntiles_hw > 1 the large kernel is called
     constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
-    constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
     constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
     constexpr uint32_t out_h = get_compile_time_arg_val(4);
     constexpr uint32_t out_w = get_compile_time_arg_val(5);
-    constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);
-    constexpr uint32_t nblocks = get_compile_time_arg_val(8);
 
     constexpr uint32_t split_reader = get_compile_time_arg_val(12);
 
     constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(13);
     constexpr uint32_t in_c = get_compile_time_arg_val(14);
+    constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(15);
     constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(16);
 
-    constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks;
 
     constexpr uint32_t in_cb_id = tt::CB::c_in0;    // and tt::CB::c_in1 for split reader
     constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
@@ -113,25 +102,19 @@ void MAIN {
     constexpr uint32_t out_cb_id = tt::CB::c_out0;
     constexpr uint32_t interm_reduction_cb_id = tt::CB::c_intermed1;
 
+    constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
+
     constexpr bool is_partial_tile = in_c < 32;
     static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
     constexpr uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : max_rows_for_reduction < 32 ? 2 : 4;
     constexpr uint32_t num_faces_in_output_tile = is_partial_tile ? 1 : 2;
     constexpr uint32_t num_out_rows = 1;
-    constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
-
-    constexpr uint32_t num_tiles_for_reduction =
-        in_ntiles_hwc > MAX_TILES_PER_REDUCTION ? MAX_TILES_PER_REDUCTION : in_ntiles_hwc;
-    uint32_t num_8_tiles_blocks = 1;
-    if (num_output_tiles > MAX_TILES_PER_REDUCTION) {
-        num_8_tiles_blocks =
-            num_output_tiles / MAX_TILES_PER_REDUCTION;    // For now, only pow of 2 number of channels are supported.
-    }
+    constexpr uint32_t num_output_tiles = in_ntiles_c / in_nblocks_c;
 
     tilizeA_B_reduce_init(
         in_cb_id,
         in_scalar_cb_id,
-        num_tiles_for_reduction,
+        num_output_tiles,
         interm_reduction_cb_id,
         num_faces_in_input_tile,
         max_rows_for_reduction);
@@ -140,27 +123,24 @@ void MAIN {
     cb_wait_front(in_scalar_cb_id, 1);
     cb_reserve_back(out_cb_id, 1);
     for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; ++i) {
-        for (uint32_t j = 0; j < num_8_tiles_blocks; j++) {
+        for (uint32_t b_i = 0; b_i < in_nblocks_c; b_i++) {
             // NOTE: Assuming in_ntiles_hw < 8 for now.
             // TODO: subblocking to support this.
-            uint32_t out_write_idx = i * num_8_tiles_blocks + j;
-
-            pack_untilize_dst_init_short(
+            uint32_t out_write_idx = i * in_nblocks_c + b_i;
+            pack_untilize_dst_init_short(
                 interm_reduction_cb_id, num_out_rows, num_faces_in_output_tile);
             cb_reserve_back(interm_reduction_cb_id, 1);
             for (uint32_t h = 0; h <= interm_reduction_chunks; h++) {
                 tile_regs_acquire();
-                reduce_h_fused(
+                reduce_h_fused(
                     in_cb_id,
                     in_scalar_cb_id,
-                    num_tiles_for_reduction,
                     i,
-                    interm_reduction_cb_id,
                     max_rows_for_reduction);
-                tile_regs_commit();
                 tile_regs_wait();
-                pack_untilize_dst(
+                tile_regs_commit();
+                pack_untilize_dst(
                     interm_reduction_cb_id,
                     1 /*out_subblock_h*/,
                     h,
@@ -171,24 +151,26 @@ void MAIN {
             cb_push_back(interm_reduction_cb_id, 1);
             pack_untilize_uninit(interm_reduction_cb_id);
             cb_wait_front(interm_reduction_cb_id, 1);
-            pack_untilize_dst_init_short(
+
+            pack_untilize_dst_init_short(
                 out_cb_id, num_out_rows, num_faces_in_output_tile);
             tile_regs_acquire();
             unpack_tilizeA_B_block(
                 interm_reduction_cb_id,
                 in_scalar_cb_id,
-                num_tiles_for_reduction,
+                num_output_tiles,
                 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
                 num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
                 max_rows_for_reduction);
-            for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) {
+            for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
                 reduce_tile_math(c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
             }
-            tile_regs_commit();
             tile_regs_wait();
-            pack_untilize_dst(
+            tile_regs_commit();
+
+            pack_untilize_dst(
                 out_cb_id,
                 1 /*out_subblock_h*/,
                 out_write_idx,
@@ -199,7 +181,6 @@ void MAIN {
             pack_untilize_uninit(out_cb_id);
         }
     }
-    // print_full_tile(out_cb_id);
     cb_push_back(out_cb_id, 1);
     cb_pop_front(in_scalar_cb_id, 1);
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp
index c6b0ea9f930..dae7348ab1d 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp
@@ -51,14 +51,12 @@ void kernel_main() {
     // channel size in bytes, multiple of 32
     const uint32_t in_nbytes_c = get_compile_time_arg_val(4);
-    const uint32_t in_nbytes_c_log2 = get_compile_time_arg_val(5);
 
     // input tensor height / width / channels
     const int32_t in_w = get_compile_time_arg_val(6);
     const uint32_t in_cb_nsticks = get_compile_time_arg_val(7);
 
     const uint32_t in_c = get_compile_time_arg_val(8);
-    const uint32_t nblocks = get_compile_time_arg_val(9);
 
     const uint32_t split_reader = get_compile_time_arg_val(10);
     const uint32_t reader_id = get_compile_time_arg_val(11);
@@ -66,14 +64,14 @@ void kernel_main() {
     // compile time args
     // value of 1 in bf16 in a uin32_t
     constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);
-
-    constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(14);
-
-    // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");
+    constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);
+    constexpr uint32_t in_cb_sz = get_compile_time_arg_val(14);
+    constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(15);
 
     constexpr uint32_t TILE_SIZE = 32 * 32;
     constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
-    constexpr uint32_t MAX_ELE_PER_REDUCTION = 512;
+    constexpr uint32_t MAX_ELE_PER_REDUCTION = 512;    // TILE_WIDTH * 8 * numbytes
+    constexpr uint32_t ROW_HW = 64;
 
     constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0;
     constexpr uint32_t in_shard_cb_id = tt::CB::c_in2;    // local input shard
@@ -81,8 +79,6 @@ void kernel_main() {
     constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
     constexpr uint32_t interm_reduction_cb_id = tt::CB::c_intermed1;
 
-    constexpr uint32_t ROW_HW = 64;
-
     // minus infinity for bfp16
     uint16_t minus_inf = 63487;
     // Reduce scalar = 1
@@ -90,10 +86,10 @@ void kernel_main() {
         cb_reserve_back(in_scalar_cb_id, 1);
 
         uint32_t bf16_one_u16 = bf16_one_u32 >> 16;
+        // fill interm buffer with minus_inf
+        fill_with_val(get_write_ptr(interm_reduction_cb_id), in_cb_sz, minus_inf);
         // fill 1 row w/ scalar
         fill_with_val(get_write_ptr(in_scalar_cb_id), ROW_HW, bf16_one_u16);
-        // fill interm buffer with minus_inf
-        fill_with_val(get_write_ptr(interm_reduction_cb_id), TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf);
         cb_push_back(in_scalar_cb_id, 1);
     }
@@ -104,53 +100,46 @@ void kernel_main() {
     uint32_t in_w_padded = in_w + 2 * pad_w;
 
-    uint32_t npages_to_reserve = nblocks;
-    uint32_t num_8_tile_blocks = 1;
     uint32_t read_bytes = in_nbytes_c;
     if (in_nbytes_c > MAX_ELE_PER_REDUCTION) {
-        num_8_tile_blocks = in_nbytes_c / MAX_ELE_PER_REDUCTION;
        read_bytes = MAX_ELE_PER_REDUCTION;    // for now, pow of 2 channels are only supported.
     }
 
     uint32_t counter = reader_id;
     uint32_t total_elems_to_reduce = window_h * window_w;
     uint32_t remaining_elems = total_elems_to_reduce % max_rows_for_reduction;
     while (counter < reader_nindices) {
-        for (uint32_t j = 0; j < num_8_tile_blocks; j++) {
-            for (uint32_t i = 0; i < nblocks; ++i) {
-                uint16_t top_left_local_index = reader_indices_ptr[counter];
-                uint32_t h_multiples = 0;
-                uint32_t processed_rows = 0;
-                uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id);
-                uint32_t out_l1_write_addr = out_l1_write_addr_base;
-                cb_reserve_back(in_cb_id, npages_to_reserve);
-                // If next is last chunk, fill whole buffer with -inf.
-                if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
-                    fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf);
-                for (uint32_t h = 0; h < window_h; ++h, h_multiples += in_w_padded) {
-                    uint32_t stick_offset = top_left_local_index + h_multiples;
+        for (uint32_t c_i = 0; c_i < in_nblocks_c; c_i++) {
+            uint16_t top_left_local_index = reader_indices_ptr[counter];
+            uint32_t processed_rows = 0;
+            cb_reserve_back(in_cb_id, 1);
+            uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id);
+            uint32_t out_l1_write_addr = out_l1_write_addr_base;
+            // fill interm buffer with minus_inf if we have only one chunk
+            if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
+                fill_with_val(out_l1_write_addr, in_cb_sz, minus_inf);
+            for (uint32_t h = 0; h < window_h; ++h) {
+                for (uint32_t w = 0; w < window_w; w++) {
+                    uint32_t stick_offset = top_left_local_index + w + h * in_w_padded;
+                    uint32_t read_offset =
-                        j * MAX_ELE_PER_REDUCTION + in_l1_read_base_addr + (stick_offset << in_nbytes_c_log2);
-                    for (uint32_t w = 0; w < window_w; w++) {
-                        noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, read_bytes);
-                        out_l1_write_addr += read_bytes;
-                        read_offset += in_nbytes_c;
-                        processed_rows++;
-                        if ((processed_rows % max_rows_for_reduction) == 0) {
-                            noc_async_read_barrier();
-                            cb_push_back(in_cb_id, npages_to_reserve);
-                            out_l1_write_addr_base = get_write_ptr(in_cb_id);
-                            out_l1_write_addr = out_l1_write_addr_base;
-                            cb_reserve_back(in_cb_id, npages_to_reserve);
-                            // If next is last chunk, fill whole buffer with -inf.
-                            if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
-                                fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf);
-                        }
+                        in_l1_read_base_addr + (stick_offset * in_nbytes_c + c_i * MAX_ELE_PER_REDUCTION);
+                    noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, read_bytes);
+                    out_l1_write_addr += read_bytes;
+                    processed_rows++;
+                    if ((processed_rows % max_rows_for_reduction) == 0) {
+                        noc_async_read_barrier();
+                        cb_push_back(in_cb_id, 1);
+                        cb_reserve_back(in_cb_id, 1);
+                        out_l1_write_addr_base = get_write_ptr(in_cb_id);
+                        out_l1_write_addr = out_l1_write_addr_base;
+                        // If next is last chunk, fill whole buffer with -inf.
+                        if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
+                            fill_with_val(out_l1_write_addr, in_cb_sz, minus_inf);
                     }
                 }
-            if (remaining_elems) {
-                noc_async_read_barrier();
-                cb_push_back(in_cb_id, npages_to_reserve);
-            }
+            }
+            if (remaining_elems) {
+                noc_async_read_barrier();
+                cb_push_back(in_cb_id, 1);
            }
         }
         counter++;
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp
index c4b9f941c9f..9f692303666 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_v2.cpp
@@ -64,8 +64,6 @@ void kernel_main() {
     constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);
 
-    // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");
-
     constexpr uint32_t TILE_WIDTH = 32;
 
     constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0;
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp
index d65ae285c31..2556fc53fcb 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_wide.cpp
@@ -67,6 +67,7 @@ void kernel_main() {
     // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");
 
     constexpr uint32_t TILE_WIDTH = 32;
+    constexpr uint32_t MAX_ELE_PER_REDUCTION = 512;    // TILE_WIDTH * 8 * numbytes
 
     constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0;
     constexpr uint32_t in_shard_cb_id = tt::CB::c_in2;    // local input shard
@@ -101,9 +102,9 @@ void kernel_main() {
         for (uint32_t h = 0; h < window_h; ++ h) {
             for (uint32_t w = 0; w < window_w; ++ w) {
                 uint32_t stick_offset = top_left_local_index + w + h * in_w_padded;
-                uint32_t read_offset = in_l1_read_base_addr + (stick_offset * in_nbytes_c + c_i * TILE_WIDTH * 8 * 2);    // 2 bytes, max 8 tiles
-                noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, TILE_WIDTH * 8 * 2);
-                out_l1_write_addr += TILE_WIDTH * 8 * 2;
+                uint32_t read_offset = in_l1_read_base_addr + (stick_offset * in_nbytes_c + c_i * MAX_ELE_PER_REDUCTION);    // 2 bytes, max 8 tiles
+                noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, MAX_ELE_PER_REDUCTION);
+                out_l1_write_addr += MAX_ELE_PER_REDUCTION;
             }
         }
         noc_async_read_barrier();
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
index 56c7465e713..b57e0de77b4 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp
@@ -33,7 +33,7 @@ void validate_maxpool(const Tensor& input, const sliding_window::SlidingWindowCo
     if (in_memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
         uint32_t num_shards_c = sliding_window_config.num_cores_c;
         const tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape();
-        TT_FATAL(input_shape[3] % num_shards_c == 0, "For width and block sharding, input channels should be divisible by num_shards");
+        TT_FATAL(input_shape[3] % num_shards_c == 0, "For width and block sharding, input channels ({}) should be divisible by num_shards ({})", input_shape[3], num_shards_c);
     }
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp
index c21873f788d..577632a9d92 100644
--- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp
@@ -62,11 +62,18 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
     uint32_t in_ntiles_c = (uint32_t) std::ceil((float) input_shape[3] / num_shards_c / tt::constants::TILE_WIDTH);
     uint32_t out_ntiles_c = (uint32_t) std::ceil((float) output_shape[3] / num_shards_c / tt::constants::TILE_WIDTH);
+
+    uint32_t max_rows_for_reduction = 16;
+    // TODO #14588: temporarily disabling 32 row reductions due to issues in large kernels
+    /* uint32_t max_rows_for_reduction = tt::constants::TILE_HEIGHT;
+    // For GRAYSKULL, make reduction for 16 rows at a time.
+    if (device->arch() == tt::ARCH::GRAYSKULL)
+        max_rows_for_reduction /= 2; */
+
     // Hardware can do reduction of 8 tiles at a time.
     // CB sizes can be restricted to this in case input channels are more than 256 to perform reduction iteratively.
-    constexpr uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;
     constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
-    const bool is_large_kernel = kernel_size_hw > MAX_SMALL_KERNEL_SIZE_HW;
+    const bool is_large_kernel = kernel_size_hw > max_rows_for_reduction;
     const bool is_wide_reduction = in_ntiles_c > MAX_TILES_PER_REDUCTION;
 
     TT_FATAL(nblocks == 1, "Multiple blocks not yet supported");
@@ -148,6 +155,9 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
         in_cb_sz = (input_shape[3] / num_shards_c * kernel_size_hw_padded) > (tt::constants::TILE_HW * MAX_TILES_PER_REDUCTION)
                        ? (tt::constants::TILE_HW * MAX_TILES_PER_REDUCTION)
                        : input_shape[3] / num_shards_c * kernel_size_hw_padded;
+        if (is_wide_reduction) {
+            in_nblocks_c = in_ntiles_c / MAX_TILES_PER_REDUCTION;
+        }
     } else {
         if (is_wide_reduction) {
             in_cb_sz = MAX_TILES_PER_REDUCTION * tt::constants::TILE_WIDTH * kernel_size_hw_padded;
@@ -188,13 +198,26 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
     auto in_tiled_cb = tt::tt_metal::CreateCircularBuffer(program, all_cores, in_tiled_cb_config);
     log_debug(tt::LogOp, "CB {} :: PS = {}, NP = {}", in_tiled_cb_id, in_tiled_cb_pagesize, in_tiled_cb_npages);
+
+    // output of reduce == writer to write
+    uint32_t out_cb_id = tt::CB::c_out0;    // output rows in RM
+                                            // after reduction
+    uint32_t out_cb_pagesize = output.shard_spec().value().shape[1] * out_nbytes / in_nblocks_c;    // there is just one row of channels after each reduction (or 1 block of c if its greater than 8 tiles)
+    uint32_t out_cb_npages = output.shard_spec().value().shape[0] * in_nblocks_c;
+    CircularBufferConfig cb_out_config = CircularBufferConfig(out_cb_npages * out_cb_pagesize, {{out_cb_id, out_df}})
+                                             .set_page_size(out_cb_id, out_cb_pagesize)
+                                             .set_globally_allocated_address(*output.buffer());
+    ;
+    auto cb_out = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_out_config);
+    log_debug(tt::LogOp, "CB {} :: PS = {}, NP = {}", out_cb_id, out_cb_pagesize, out_cb_npages);
+
     if (is_large_kernel) {
         uint32_t max_pool_partials_cb_id = tt::CB::c_intermed1;    // max_pool partials
-        uint32_t max_pool_partials_cb_pagesize = in_cb_sz;
+        uint32_t max_pool_partials_cb_pagesize = std::min(out_cb_pagesize, TILE_SIZE * 8 * out_nbytes);
         uint32_t max_pool_partials_cb_npages = nblocks;
         CircularBufferConfig max_pool_partials_cb_config =
             CircularBufferConfig(
-                max_pool_partials_cb_npages * max_pool_partials_cb_pagesize, {{max_pool_partials_cb_id, in_df}})
+                max_pool_partials_cb_npages * max_pool_partials_cb_pagesize, {{max_pool_partials_cb_id, out_df}})
                 .set_page_size(max_pool_partials_cb_id, max_pool_partials_cb_pagesize);
         auto max_pool_partials_cb = tt::tt_metal::CreateCircularBuffer(program, all_cores, max_pool_partials_cb_config);
         log_debug(
@@ -204,19 +227,6 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
             max_pool_partials_cb_pagesize,
             max_pool_partials_cb_npages);
     }
-
-    // output of reduce == writer to write
-    uint32_t out_cb_id = tt::CB::c_out0;    // output rows in RM
-                                            // after reduction
-    uint32_t out_cb_pagesize = output.shard_spec().value().shape[1] * out_nbytes / in_nblocks_c;    // there is just one row of channels after each reduction (or 1 block of c if its greater than 8 tiles)
-    uint32_t out_cb_npages = output.shard_spec().value().shape[0] * in_nblocks_c;
-    CircularBufferConfig cb_out_config = CircularBufferConfig(out_cb_npages * out_cb_pagesize, {{out_cb_id, out_df}})
-                                             .set_page_size(out_cb_id, out_cb_pagesize)
-                                             .set_globally_allocated_address(*output.buffer());
-    ;
-    auto cb_out = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_out_config);
-    log_debug(tt::LogOp, "CB {} :: PS = {}, NP = {}", out_cb_id, out_cb_pagesize, out_cb_npages);
-
     TT_FATAL(output.memory_config().is_sharded(), "Output memory config needs to be sharded");
 
 #if 1
@@ -267,10 +277,6 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
     }
 #endif
 
-    uint32_t max_rows_for_reduction = tt::constants::TILE_HEIGHT;
-    /* For GRAYSKULL, make reduction for 16 rows at a time.*/
-    if (device->arch() == tt::ARCH::GRAYSKULL)
-        max_rows_for_reduction /= 2;
     /**
      * Reader Kernel: input rows -> input cb
      */
@@ -292,6 +298,7 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
         0,    // split reader id
         bf16_one_u32,
         in_nblocks_c,
+        in_cb_sz,
         max_rows_for_reduction};
 
     std::vector reader1_ct_args = {
@@ -309,6 +316,7 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
         1,    // split reader id
         bf16_one_u32,
        in_nblocks_c,
+        in_cb_sz,
         max_rows_for_reduction};
 
     std::string reader_kernel_fname;
@@ -399,7 +407,6 @@ MaxPool2D::MultiCore::cached_program_t MaxPool2D::MultiCore::create(const operat
     uint32_t out_h = output_shape[1];
     uint32_t out_w = output_shape[2];
 
-    bool is_width_sharded = input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED;
     bool is_block_sharded = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED;
 
     auto pad_metadata = sliding_window::generate_pad_metadata(sliding_window_config);
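
For reference, the `(13, 13)` kernel with `(6, 6)` padding exercised by the new "wide yolo kernel" shapes can be sanity-checked against the same PyTorch golden that `run_max_pool` builds on the host. The following is a minimal, device-free sketch; the shape and variable names are chosen here for illustration and are not taken verbatim from the test:

```python
import math
import torch

# Illustrative golden reference mirroring run_max_pool's host-side computation:
# a 13x13 max pool with padding 6, stride 1, dilation 1 on a [1, 512, 10, 10] input.
act = torch.randn(1, 512, 10, 10)  # N, C, H, W
golden = torch.nn.functional.max_pool2d(
    act, kernel_size=(13, 13), stride=(1, 1), padding=(6, 6), dilation=(1, 1)
)

# Same output-size formula the test uses (dilation_h == 1 here).
out_h = math.floor((10 + 2 * 6 - (1 * 13 - 1) - 1) / 1) + 1
assert golden.shape == (1, 512, out_h, out_h)  # torch.Size([1, 512, 10, 10])
```

With stride 1 and padding equal to half the kernel size the spatial dimensions are preserved, which is why the new skip rules at the top of `run_max_pool` only accept matched pairs such as `kernel_h == 13` with `pad_h == 6`.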