#15205: Remove restriction of input_nsticks_per_core % w == 0
### Problem description
Currently, the whole input row is processed per core, which is inefficient since
the other cores can sit idle.

### What's changed
Distribute the work across all available cores instead of assigning whole input rows per core.
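
For HEIGHT sharding this means the shard count is derived from the flattened N*H*W stick count rather than from whole image rows, so the `input_nsticks_per_core % w == 0` restriction goes away. Below is a minimal sketch of that core-count selection, mirroring the updated logic in `test_upsample.py` further down; the helper name `pick_ncores` is illustrative, not the op's host code.

```python
def pick_ncores(batch, height, width, grid_rows, grid_cols):
    """Largest core count that evenly divides the flattened NHW stick count."""
    total_sticks = batch * height * width  # sticks, no longer whole rows (batch * height)
    max_ncores = min(total_sticks, grid_rows * grid_cols)
    for ncores in range(max_ncores, 0, -1):
        if total_sticks % ncores == 0:  # even shard height; no in_w divisibility required
            return ncores
    return 1

# Example: a 1x32x5x4 input on an 8x8 grid can now use 20 cores,
# where the old row-based split would have used only 5.
print(pick_ncores(1, 5, 4, 8, 8))  # 20
```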

### Checklist
- [X] Post commit CI passes
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12689604385)
- [X] Nightly fast dispatch
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12689606390)
- [X] Model regression CI testing passes
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12689608961)
- [X] Device performance regression CI testing passes (if applicable)
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12689611247)
- [X] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12689614322)

---------

Signed-off-by: Nilaykumar Patel <nkpatel@tenstorrent.com>
nkpatel-tt authored Jan 10, 2025
1 parent e0f11a7 commit a94c89e
Showing 4 changed files with 201 additions and 73 deletions.
@@ -34,7 +34,7 @@
 @pytest.mark.models_device_performance_bare_metal
 @pytest.mark.parametrize(
     "batch, groups, expected_device_perf_fps",
-    ((1, 2, 1053.0),),
+    ((1, 2, 1040.0),),
 )
 def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
     command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
tests/ttnn/unit_tests/operations/test_upsample.py (29 changes: 21 additions & 8 deletions)
@@ -10,7 +10,7 @@
 import torch
 import torch.nn as nn
 import ttnn
-from models.utility_functions import skip_for_grayskull, skip_for_blackhole
+from models.utility_functions import skip_for_grayskull, skip_for_blackhole, is_grayskull
 from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


@@ -109,12 +109,25 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
         [1, 64, 132, 10],
         [1, 32, 8, 8],
         [2, 640, 32, 32],
+        # some random shapes
+        [1, 32, 5, 4],
+        [3, 32, 4, 4],
+        [5, 64, 5, 5],
+        [1, 128, 5, 8],
+        [1, 32, 5, 4],
+        [1, 64, 128, 17],
+        [1, 64, 132, 19],
     ],
 )
-@pytest.mark.parametrize("scale_h", [2])
-@pytest.mark.parametrize("scale_w", [2])
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+@pytest.mark.parametrize("scale_h", [2, 3])
+@pytest.mark.parametrize("scale_w", [2, 3])
 @pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK])
-def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
+@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR])
+def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy, shard_orientation):
+    if is_grayskull() and (scale_h > 2 or scale_w > 2):
+        pytest.skip("Skipping test because it won't fit in L1!")
+
     ## input shape is N C H W
     batch_size, num_channels, height, width = input_shape
     torch.manual_seed(0)
@@ -136,15 +149,15 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
     max_grid_size = (device_grid.y, device_grid.x)
     if shard_strategy == ttnn.ShardStrategy.HEIGHT:
         ## nsticks per shard should be divisible by in_w
-        max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1])
+        max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
         nshards = max_nshards
         while nshards > 0:
-            if batch_size * height % nshards == 0:
+            if batch_size * height * width % nshards == 0:
                 break
             nshards -= 1
         ncores = nshards
     elif shard_strategy == ttnn.ShardStrategy.BLOCK:
-        max_nshards_h = min(batch_size * height, max_grid_size[0])  ## height along NHW
+        max_nshards_h = min(batch_size * height * width, max_grid_size[0])  ## height along NHW
         max_nshards_w = min(num_channels, max_grid_size[1])  ## width along C
         ## find nshards_h along NHW
         nshards_h = max_nshards_h
@@ -177,7 +190,6 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
     # )
 
     shard_grid = get_shard_grid_from_num_cores(device, ncores)
-    shard_orientation = ttnn.ShardOrientation.ROW_MAJOR
 
     if shard_strategy == ttnn.ShardStrategy.BLOCK:
         tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
@@ -351,6 +363,7 @@ def test_bilinear_multi_core(
 
     ## compare the results
     torch_result = torch_result.permute(0, 2, 3, 1)
+
     passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999)
     allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1)
     logger.info(pcc_msg)
@@ -7,59 +7,48 @@
 
 void kernel_main() {
     uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
-    uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
+    uint32_t in_nsticks_per_core = get_arg_val<uint32_t>(1);
     uint32_t scale_h = get_arg_val<uint32_t>(2);
     uint32_t scale_w = get_arg_val<uint32_t>(3);
     uint32_t in_w = get_arg_val<uint32_t>(4);
     uint32_t out_w = get_arg_val<uint32_t>(5);
 
     constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
     constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
     constexpr uint32_t is_reader = get_compile_time_arg_val(2);
 
-    uint32_t in_image_row_nbytes = in_w * stick_nbytes;
-    uint32_t out_image_row_nbytes = out_w * stick_nbytes;
-    uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
-    uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
-    uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
-    uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
-    uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes;
-    uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes;
-
-    cb_reserve_back(out_cb_id, out_w);
-
-    // assuming shard begins with a new row. TODO: generalize?
-    for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) {
-        uint32_t l1_write_addr_image_row_start = l1_write_addr;
-        for (uint32_t i = 0; i < in_w; ++i) {
+    constexpr uint32_t config_cb_id = get_compile_time_arg_val(3);
+
+    uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
+    uint32_t out_nsticks_per_core = reader_nsticks_per_core * scale_h * scale_w;
+    uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
+    uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
+    uint32_t l1_read_addr = get_read_ptr(in_cb_id);
+    uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * scale_w * stick_nbytes;
+
+    uint32_t config_l1_addr = get_read_ptr(config_cb_id);
+    volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(config_l1_addr);
+
+    uint32_t reader_idx = 0;
+    if constexpr (!is_reader) {
+        /* For each input stick there are 2 entries in config cb {{core_coords.x, core_coords.y}, stick_offset(in
+         * input_cb)} so multiply input image_row_begin with (2 * scale_h) */
+        reader_idx = (2 * scale_h) * image_row_begin;
+    }
+    cb_reserve_back(out_cb_id, out_nsticks_per_core);
+
+    for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
+        for (uint32_t sh = 0; sh < scale_h; sh++) {
+            uint16_t cores = config_data[reader_idx++];
+            uint16_t corey = cores & 0xFF;
+            uint16_t corex = cores >> 8;
+            uint16_t offset = config_data[reader_idx++];
+            uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes);
+            // replicate stick scale_w times.
-            for (uint32_t sw = 0; sw < scale_w; ++sw) {
-                // replicate stick scale_w times.
-                if constexpr (is_reader) {
-                    uint64_t src_noc_addr = get_noc_addr(l1_read_addr);
-                    noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
-                } else {
-                    uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
-                    noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
-                }
+            for (uint32_t sw = 0; sw < scale_w; sw++) {
+                noc_async_read(src_remote_addr, l1_write_addr, stick_nbytes);
                 l1_write_addr += stick_nbytes;
             }
-            l1_read_addr += stick_nbytes;
         }
-
-        // Duplicate the whole image row in one shot
-        if constexpr (is_reader) {
-            uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start);
-            noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes);
-        } else {
-            uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
-            noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes);
-        }
-        l1_write_addr += out_image_row_nbytes;
     }
 
-    cb_push_back(out_cb_id, out_w);
-
-    noc_async_write_barrier();
     noc_async_read_barrier();
+    cb_push_back(out_cb_id, out_nsticks_per_core);
 }
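
Each slot in the config CB read by the kernel above is a pair of uint16 values: the packed (x, y) coordinates of the core that holds the source stick, followed by that stick's offset within the holder's input CB; per the kernel's comment, `2 * scale_h` such values are read per input stick. The following is a minimal, hypothetical host-side packing sketch consistent with how the kernel unpacks the entries; `pack_config_entries` and its argument layout are assumptions, not the op's actual host code.

```python
def pack_config_entries(source_sticks):
    """source_sticks: (core_x, core_y, offset_in_input_cb) per config slot, in kernel read order."""
    entries = []
    for core_x, core_y, offset in source_sticks:
        entries.append((core_x << 8) | core_y)  # kernel unpacks: corex = v >> 8, corey = v & 0xFF
        entries.append(offset)                  # stick index; kernel scales it by stick_nbytes
    return entries

# Example: a source stick held by core (1, 2) at stick offset 3 packs to [258, 3].
print(pack_config_entries([(1, 2, 3)]))  # [258, 3]
```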