Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MaxPool Block / Width Sharding with Large Kernels / Wide Reductions #14531

Merged
merged 11 commits into from
Nov 8, 2024
167 changes: 136 additions & 31 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import math

from models.utility_functions import is_wormhole_b0
from models.utility_functions import is_wormhole_b0, is_grayskull, is_x2_harvested
from tests.ttnn.utils_for_testing import assert_with_pcc

import ttnn
Expand All @@ -34,8 +34,15 @@ def run_max_pool(
if shard_scheme != ttnn.TensorMemoryLayout.WIDTH_SHARDED:
if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
pytest.skip("Invalid case")
if (kernel_h == 3 and pad_h != 1) or (kernel_h == 2 and pad_h != 0):
pytest.skip("kernel size and padding combination not supported")

if (
(kernel_h == 13 and pad_h != 6)
or (kernel_h == 9 and pad_h != 4)
or (kernel_h == 5 and pad_h != 2)
or (kernel_h == 3 and pad_h != 1)
or (kernel_h == 2 and pad_h != 0)
):
pytest.skip("kernel size and padding combination not supported")
mywoodstock marked this conversation as resolved.
Show resolved Hide resolved

out_h = math.floor((in_h + 2 * pad_h - (dilation_h * kernel_h - 1) - 1) / stride_h) + 1
out_w = math.floor((in_w + 2 * pad_w - (dilation_w * kernel_w - 1) - 1) / stride_w) + 1
Expand All @@ -47,15 +54,57 @@ def run_max_pool(
if in_c % 16 != 0:
pytest.skip("Current maxpool writer needs nchannels to be multiple of 16!")
if in_c == 16 and dtype == ttnn.bfloat8_b and in_n * in_h * in_w > 600000:
pytest.skip("This case runs out of memory on Grayskull")
pytest.skip("This case runs out of memory")
if in_n > 16 and in_c > 64 and dtype == ttnn.bfloat8_b and is_wormhole_b0():
pytest.skip("This case runs out of memory on Wormhole b0")
if (
stride == (1, 1)
and (act_shape == [16, 64, 112, 112] or act_shape == [4, 16, 1056, 160] or act_shape == [16, 16, 528, 80])
and is_wormhole_b0()
):
pytest.skip("This case runs out of memory on Wormhole b0")
if stride == (1, 1) and act_shape == [8, 16, 528, 80] and is_grayskull():
pytest.skip("This case runs out of memory on Grayskull")
if kernel_h > 3 and kernel_w > 3 and act_shape == [16, 64, 112, 112] and is_grayskull():
pytest.skip("This case runs out of memory on Grayskull")
if kernel_size == (13, 13) and act_shape == [128, 32, 132, 20] and is_grayskull():
pytest.skip("This case runs out of memory on Grayskull")
if kernel_h > 5 and kernel_w > 5 and act_shape == [16, 64, 112, 112] and is_x2_harvested(device):
pytest.skip("This case runs out of memory on Wormhole X2")
if stride == (1, 1) and act_shape == [128, 32, 132, 20] and is_x2_harvested(device):
pytest.skip("This case runs out of memory on Wormhole X2")
if stride == (1, 1) and kernel_size == (13, 13) and act_shape == [32, 32, 264, 40] and is_x2_harvested(device):
pytest.skip("This case runs out of memory on Wormhole X2")
if (
dtype == ttnn.bfloat8_b
and (act_shape == [4, 16, 1056, 160] or act_shape == [16, 16, 528, 80])
and is_x2_harvested(device)
):
pytest.skip("This case runs out of memory on Wormhole X2")

if shard_scheme == ttnn.TensorMemoryLayout.WIDTH_SHARDED:
if in_c < max_cores:
pytest.skip("Width sharding requires channels >= cores")
if in_c / max_cores < 16:
pytest.skip("Width sharding requires large enough channels to shard (at least 16 per core)")
if (
kernel_size == (13, 13)
and (act_shape == [8, 4096, 10, 16] or act_shape == [1, 32768, 10, 10])
and is_grayskull()
):
pytest.skip("This case runs out of memory on Grayskull")
if (
stride == (1, 1)
and kernel_h > 5
and kernel_w > 5
and (act_shape == [4, 1024, 40, 40] or act_shape == [2, 2048, 40, 40] or act_shape == [8, 4096, 10, 16])
and is_x2_harvested(device)
):
pytest.skip("This case runs out of memory on Wormhole X2")
if kernel_h > 5 and kernel_w > 5 and act_shape == [8, 4096, 10, 16] and is_x2_harvested(device):
pytest.skip("This case runs out of memory on Wormhole X2")
if kernel_size == (13, 13) and act_shape == [1, 32768, 10, 10] and is_x2_harvested(device):
pytest.skip("This case runs out of memory on Wormhole X2")

if shard_scheme == ttnn.TensorMemoryLayout.BLOCK_SHARDED:
if in_c < cores_x:
Expand Down Expand Up @@ -150,7 +199,12 @@ def run_max_pool(
output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])

output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2)) ## N, C, H, W
passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch)

pcc_thresh = 1.0
if dtype == ttnn.bfloat8_b:
pcc_thresh = 0.9997

passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch, pcc_thresh)

logger.debug(f"Passing: {passing}, PCC: {pcc}")

Expand Down Expand Up @@ -213,6 +267,8 @@ def run_max_pool(
[1, 256, 56, 56],
[1, 512, 28, 28],
[1, 512, 14, 14],
# wide yolo kernel
[1, 512, 10, 10],
)
),
)
Expand All @@ -221,21 +277,36 @@ def run_max_pool(
(
(2, 2),
(3, 3),
(5, 5),
(9, 9),
(13, 13),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
(2, 2),
(4, 4),
(6, 6),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
(
(1, 1),
(2, 2),
),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
"dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
def test_run_max_pool(
act_shape,
kernel_size,
Expand All @@ -249,26 +320,6 @@ def test_run_max_pool(
run_max_pool(act_shape, kernel_size, padding, stride, dilation, device, dtype)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
    "act_shape",  ## NCHW
    ([8, 64, 112, 112], [1, 512, 10, 10]),
)
@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
def test_run_max_pool_mem_config(
    act_shape,
    device,
    memory_config,
    use_program_cache,
):
    """Run max_pool with an explicit output memory config (L1 vs DRAM)."""
    # Fixed pooling parameters: 3x3 kernel, padding 1, stride 2, no dilation, bfloat16.
    kernel = (3, 3)
    pad = (1, 1)
    strides = (2, 2)
    dilation = (1, 1)
    run_max_pool(
        act_shape,
        kernel,
        pad,
        strides,
        dilation,
        device,
        ttnn.bfloat16,
        memory_config=memory_config,
    )


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"act_shape", ## NCHW
Expand All @@ -282,6 +333,8 @@ def test_run_max_pool_mem_config(
[4, 1024, 40, 40],
[2, 2048, 40, 40],
[8, 4096, 10, 16],
# wide yolo kernel
[1, 32768, 10, 10],
)
),
)
Expand All @@ -290,21 +343,36 @@ def test_run_max_pool_mem_config(
(
(2, 2),
(3, 3),
(5, 5),
(9, 9),
(13, 13),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
(2, 2),
(4, 4),
(6, 6),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
(
(1, 1),
(2, 2),
),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
"dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
def test_run_max_pool_width_shard(
act_shape,
kernel_size,
Expand Down Expand Up @@ -359,6 +427,8 @@ def test_run_max_pool_width_shard(
[4, 16, 1056, 160],
[8, 16, 528, 80],
[16, 16, 528, 80],
# wide yolo kernel
[1, 4096, 10, 10],
)
),
)
Expand All @@ -367,21 +437,36 @@ def test_run_max_pool_width_shard(
(
(2, 2),
(3, 3),
(5, 5),
(9, 9),
(13, 13),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
(2, 2),
(4, 4),
(6, 6),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
(
(1, 1),
(2, 2),
),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
"dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
def test_run_max_pool_block_shard(
act_shape,
kernel_size,
Expand All @@ -404,6 +489,26 @@ def test_run_max_pool_block_shard(
)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
    "act_shape",  ## NCHW
    (
        [8, 64, 112, 112],
        [1, 512, 10, 10],
    ),
)
@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
def test_run_max_pool_mem_config(
    act_shape,
    device,
    memory_config,
    use_program_cache,
):
    """Sweep max_pool over L1 and DRAM output memory configs with fixed pool params."""
    # 3x3 window, padding 1, stride 2, dilation 1, dtype bfloat16 throughout.
    kernel_hw, pad_hw, stride_hw, dilation_hw = (3, 3), (1, 1), (2, 2), (1, 1)
    run_max_pool(
        act_shape, kernel_hw, pad_hw, stride_hw, dilation_hw, device, ttnn.bfloat16, memory_config=memory_config
    )


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"act_shape", ## NCHW
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,22 @@
}
#endif

template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
template<uint32_t num_output_tiles, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
inline void reduce_h_fused(
const uint32_t in_cb_id,
const uint32_t in_scalar_cb_id,
const uint32_t in_ntiles_hwc_block,
const uint32_t in_stick_index,
const uint32_t out_cb_id) {

constexpr uint32_t num_output_tiles = out_ntiles_c / in_nblocks_c;
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;
for (uint32_t c_i = 0; c_i < in_nblocks_c; ++ c_i) {
for (uint32_t b_i = 0; b_i < in_nblocks_c; ++ b_i) {
cb_reserve_back(out_cb_id, 1);
const uint32_t curr_in_cb_id = split_reader ? (in_cb_id + (in_stick_index & 0x1)) : in_cb_id;
cb_wait_front(curr_in_cb_id, 1);
tile_regs_acquire();
unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < in_ntiles_c / in_nblocks_c; ++c_i) {
unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, num_output_tiles, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
}
cb_pop_front(curr_in_cb_id, 1);
Expand All @@ -79,20 +77,16 @@ void MAIN {
// NOTE: here it is assumed that in_ntiles_hw == 1. General cases not handled yet.
constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
constexpr uint32_t out_h = get_compile_time_arg_val(4);
constexpr uint32_t out_w = get_compile_time_arg_val(5);
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);

constexpr uint32_t split_reader = get_compile_time_arg_val(12);

constexpr uint32_t nsticks_per_core = get_compile_time_arg_val(13);
constexpr uint32_t in_c = get_compile_time_arg_val(14);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(15);

constexpr uint32_t num_output_tiles = out_ntiles_c;

constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
constexpr uint32_t in_tiled_cb_id = tt::CB::c_intermed0;
Expand All @@ -103,13 +97,19 @@ void MAIN {
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;

constexpr uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;
tilizeA_B_reduce_init<true>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, out_cb_id, num_faces_in_tile, window_size_hw);
pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
constexpr uint32_t num_output_tiles = in_ntiles_c / in_nblocks_c;
mywoodstock marked this conversation as resolved.
Show resolved Hide resolved
tilizeA_B_reduce_init<true>(
in_cb_id,
in_scalar_cb_id,
num_output_tiles,
out_cb_id,
num_faces_in_tile,
window_size_hw);
pack_untilize_dst_init_short<in_ntiles_c>(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */

cb_wait_front(in_scalar_cb_id, 1);
for (uint32_t i = 0; i < nsticks_per_core; ++ i) {
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, is_partial_tile, split_reader, window_size_hw, in_nblocks_c>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, i, out_cb_id);
reduce_h_fused<num_output_tiles, is_partial_tile, split_reader, window_size_hw, in_nblocks_c>(in_cb_id, in_scalar_cb_id, i, out_cb_id);
}
cb_pop_front(in_scalar_cb_id, 1);
}
Expand Down
Loading
Loading