Skip to content

Commit

Permalink
#14249: Fixed bug for width and block sharding with large kernel sizes and wide reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
wransom-TT committed Oct 31, 2024
1 parent f2f17c4 commit af60c98
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 191 deletions.
137 changes: 104 additions & 33 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,15 @@ def run_max_pool(
if shard_scheme != ttnn.TensorMemoryLayout.WIDTH_SHARDED:
if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
pytest.skip("Invalid case")
if (kernel_h == 3 and pad_h != 1) or (kernel_h == 2 and pad_h != 0):
pytest.skip("kernel size and padding combination not supported")

if (
(kernel_h == 13 and pad_h != 6)
or (kernel_h == 9 and pad_h != 4)
or (kernel_h == 5 and pad_h != 2)
or (kernel_h == 3 and pad_h != 1)
or (kernel_h == 2 and pad_h != 0)
):
pytest.skip("kernel size and padding combination not supported")

out_h = math.floor((in_h + 2 * pad_h - (dilation_h * kernel_h - 1) - 1) / stride_h) + 1
out_w = math.floor((in_w + 2 * pad_w - (dilation_w * kernel_w - 1) - 1) / stride_w) + 1
Expand All @@ -50,6 +57,8 @@ def run_max_pool(
pytest.skip("This case runs out of memory on Grayskull")
if in_n > 16 and in_c > 64 and dtype == ttnn.bfloat8_b and is_wormhole_b0():
pytest.skip("This case runs out of memory on Wormhole b0")
if stride == (1, 1) and in_n * in_c * in_h * in_w > 1e7:
pytest.skip("This case runs out of memory")

if shard_scheme == ttnn.TensorMemoryLayout.WIDTH_SHARDED:
if in_c < max_cores:
Expand All @@ -64,10 +73,16 @@ def run_max_pool(
pytest.skip("Block sharding requires large enough channels to shard (at least 16 per core)")

torch.manual_seed(0)
torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)
torch.set_printoptions(precision=3, sci_mode=True, linewidth=500, threshold=10000, edgeitems=32)

## construct the tensor in NCHW shape
act = torch.randn(act_shape, dtype=torch.bfloat16)
# act = torch.empty(act_shape, dtype=torch.bfloat16)
# for n in range(act_shape[0]):
# for c in range(act_shape[1]):
# for h in range(act_shape[2]):
# for w in range(act_shape[3]):
# act[n, c, h, w] = h * in_w + w
# act = torch.zeros(act_shape, dtype=torch.bfloat16)
# act = torch.ones(act_shape, dtype=torch.bfloat16)
# act = torch.arange(0, volume(act_shape), dtype=torch.bfloat16).reshape(act_shape)
Expand Down Expand Up @@ -150,7 +165,12 @@ def run_max_pool(
output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])

output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2)) ## N, C, H, W
passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch)

pcc_thresh = 0.9999
if dtype == ttnn.bfloat8_b:
pcc_thresh = 0.9997

passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch, pcc_thresh)

logger.debug(f"Passing: {passing}, PCC: {pcc}")

Expand Down Expand Up @@ -181,7 +201,7 @@ def run_max_pool(
[1, 64, 112, 112],
[4, 64, 112, 112],
[8, 64, 112, 112],
[16, 64, 112, 112],
[16, 64, 112, 112], # oom with stride (1,1)
# [20, 64, 112, 112], ## oom
## hpr shapes
[8, 32, 132, 20],
Expand All @@ -196,15 +216,15 @@ def run_max_pool(
# [64, 32, 264, 40], ## oom
# [128, 32, 264, 40], ## oom
# [256, 32, 264, 40], ## oom
[4, 16, 1056, 160],
[4, 16, 1056, 160], # oom with stride (1,1)
# [8, 16, 1056, 160], ## oom
# [16, 16, 1056, 160], ## oom
# [32, 16, 1056, 160], ## oom
# [64, 16, 1056, 160], ## oom
# [128, 16, 1056, 160], ## oom
# [256, 16, 1056, 160], ## oom
[8, 16, 528, 80],
[16, 16, 528, 80],
[16, 16, 528, 80], # oom with stride (1,1)
# [32, 16, 528, 80], ## oom
# [64, 16, 528, 80], ## oom
# [128, 16, 528, 80], ## oom
Expand All @@ -213,6 +233,8 @@ def run_max_pool(
[1, 256, 56, 56],
[1, 512, 28, 28],
[1, 512, 14, 14],
# wide yolo kernel
[1, 512, 10, 10],
)
),
)
Expand All @@ -221,21 +243,36 @@ def run_max_pool(
(
(2, 2),
(3, 3),
(5, 5),
(9, 9),
(13, 13),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
(2, 2),
(4, 4),
(6, 6),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
(
(1, 1),
(2, 2),
),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
"dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
def test_run_max_pool(
act_shape,
kernel_size,
Expand All @@ -249,26 +286,6 @@ def test_run_max_pool(
run_max_pool(act_shape, kernel_size, padding, stride, dilation, device, dtype)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
    "act_shape",  ## NCHW
    (
        [8, 64, 112, 112],
        [1, 512, 10, 10],
    ),
)
@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
def test_run_max_pool_mem_config(
    act_shape,
    device,
    memory_config,
    use_program_cache,
):
    """Run one fixed max-pool configuration for each parametrized memory config.

    Every case uses a (3, 3) kernel, (1, 1) padding, (2, 2) stride, (1, 1)
    dilation and ``ttnn.bfloat16``; only ``act_shape`` and ``memory_config``
    vary. The ``memory_config`` value is forwarded to ``run_max_pool`` as a
    keyword argument. ``use_program_cache`` is requested as a fixture and is
    not referenced in the body.
    """
    # Name the fixed pooling parameters so the positional call below is readable.
    kernel_size = (3, 3)
    padding = (1, 1)
    stride = (2, 2)
    dilation = (1, 1)

    run_max_pool(
        act_shape,
        kernel_size,
        padding,
        stride,
        dilation,
        device,
        ttnn.bfloat16,
        memory_config=memory_config,
    )


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"act_shape", ## NCHW
Expand All @@ -282,6 +299,8 @@ def test_run_max_pool_mem_config(
[4, 1024, 40, 40],
[2, 2048, 40, 40],
[8, 4096, 10, 16],
# wide yolo kernel
[1, 32768, 10, 10],
)
),
)
Expand All @@ -290,21 +309,36 @@ def test_run_max_pool_mem_config(
(
(2, 2),
(3, 3),
(5, 5),
(9, 9),
(13, 13),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
(2, 2),
(4, 4),
(6, 6),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
(
(1, 1),
(2, 2),
),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
"dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
def test_run_max_pool_width_shard(
act_shape,
kernel_size,
Expand Down Expand Up @@ -359,6 +393,8 @@ def test_run_max_pool_width_shard(
[4, 16, 1056, 160],
[8, 16, 528, 80],
[16, 16, 528, 80],
# wide yolo kernel
[1, 4096, 10, 10],
)
),
)
Expand All @@ -367,21 +403,36 @@ def test_run_max_pool_width_shard(
(
(2, 2),
(3, 3),
(5, 5),
(9, 9),
(13, 13),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
(2, 2),
(4, 4),
(6, 6),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
(
(1, 1),
(2, 2),
),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
@pytest.mark.parametrize(
"dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
def test_run_max_pool_block_shard(
act_shape,
kernel_size,
Expand All @@ -404,6 +455,26 @@ def test_run_max_pool_block_shard(
)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
    "act_shape",  ## NCHW
    (
        [8, 64, 112, 112],
        [1, 512, 10, 10],
    ),
)
@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
def test_run_max_pool_mem_config(
    act_shape,
    device,
    memory_config,
    use_program_cache,
):
    """Run one fixed max-pool configuration for each parametrized memory config.

    Every case uses a (3, 3) kernel, (1, 1) padding, (2, 2) stride, (1, 1)
    dilation and ``ttnn.bfloat16``; only ``act_shape`` and ``memory_config``
    vary. The ``memory_config`` value is forwarded to ``run_max_pool`` as a
    keyword argument. ``use_program_cache`` is requested as a fixture and is
    not referenced in the body.
    """
    # Name the fixed pooling parameters so the positional call below is readable.
    kernel_size = (3, 3)
    padding = (1, 1)
    stride = (2, 2)
    dilation = (1, 1)

    run_max_pool(
        act_shape,
        kernel_size,
        padding,
        stride,
        dilation,
        device,
        ttnn.bfloat16,
        memory_config=memory_config,
    )


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"act_shape", ## NCHW
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,22 @@
}
#endif

template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
template<uint32_t num_output_tiles, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
inline void reduce_h_fused(
const uint32_t in_cb_id,
const uint32_t in_scalar_cb_id,
const uint32_t in_ntiles_hwc_block,
const uint32_t in_stick_index,
const uint32_t out_cb_id) {

constexpr uint32_t num_output_tiles = out_ntiles_c / in_nblocks_c;
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;
for (uint32_t c_i = 0; c_i < in_nblocks_c; ++ c_i) {
for (uint32_t b_i = 0; b_i < in_nblocks_c; ++ b_i) {
cb_reserve_back(out_cb_id, 1);
const uint32_t curr_in_cb_id = split_reader ? (in_cb_id + (in_stick_index & 0x1)) : in_cb_id;
cb_wait_front(curr_in_cb_id, 1);
tile_regs_acquire();
unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < in_ntiles_c / in_nblocks_c; ++c_i) {
unpack_tilizeA_B_block(curr_in_cb_id, in_scalar_cb_id, num_output_tiles, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, num_faces_in_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
}
cb_pop_front(curr_in_cb_id, 1);
Expand All @@ -79,20 +77,16 @@ void MAIN {
// NOTE: here it is assumed that in_ntiles_hw == 1. General cases not handled yet.
constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
constexpr uint32_t out_h = get_compile_time_arg_val(4);
constexpr uint32_t out_w = get_compile_time_arg_val(5);
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);

constexpr uint32_t split_reader = get_compile_time_arg_val(12);

constexpr uint32_t nsticks_per_core = get_compile_time_arg_val(13);
constexpr uint32_t in_c = get_compile_time_arg_val(14);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(15);

constexpr uint32_t num_output_tiles = out_ntiles_c;

constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
constexpr uint32_t in_tiled_cb_id = tt::CB::c_intermed0;
Expand All @@ -103,13 +97,19 @@ void MAIN {
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;

constexpr uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;
tilizeA_B_reduce_init<true>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, out_cb_id, num_faces_in_tile, window_size_hw);
pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
constexpr uint32_t num_output_tiles = in_ntiles_c / in_nblocks_c;
tilizeA_B_reduce_init<true>(
in_cb_id,
in_scalar_cb_id,
num_output_tiles,
out_cb_id,
num_faces_in_tile,
window_size_hw);
pack_untilize_dst_init_short<in_ntiles_c>(out_cb_id, num_out_rows, num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */

cb_wait_front(in_scalar_cb_id, 1);
for (uint32_t i = 0; i < nsticks_per_core; ++ i) {
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, is_partial_tile, split_reader, window_size_hw, in_nblocks_c>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc_block, i, out_cb_id);
reduce_h_fused<num_output_tiles, is_partial_tile, split_reader, window_size_hw, in_nblocks_c>(in_cb_id, in_scalar_cb_id, i, out_cb_id);
}
cb_pop_front(in_scalar_cb_id, 1);
}
Expand Down
Loading

0 comments on commit af60c98

Please sign in to comment.