Increase number of cores used to calculate upsample for YOLO and Stable Diffusion #16351

Draft · wants to merge 15 commits into base: main
@@ -20,10 +20,8 @@ def __init__(self, input_height, input_width, in_channels, scale_factor):

## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape
ncores = None
grid_sizes = {1024: (8, 5), 256: (8, 8), 64: (4, 8)}
max_grid_size = grid_sizes[input_height * input_width]

max_nshards_h = min(batch_size * input_height, max_grid_size[0]) ## height along NHW
max_grid_size = (8, 8)
max_nshards_h = min(batch_size * input_height * input_width, max_grid_size[0]) ## height along NHW
max_nshards_w = min(in_channels, max_grid_size[1]) ## width along C
## find nshards_h along NHW
nshards_h = max_nshards_h
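The hunk above replaces the hard-coded grid_sizes lookup in the Stable Diffusion upsample with a fixed (8, 8) grid ceiling and widens the NHW bound from batch_size * input_height to batch_size * input_height * input_width. A rough sketch of the effect on core count, assuming the rest of the search mirrors the helper added to models/demos/yolov4/ttnn/common.py below (the shapes here are illustrative, not taken from the diff):

import math

def pick_nshards(batch, height, width, channels, max_grid):
    # Block sharding: N*H*W sticks are split along grid rows, channels along grid columns.
    nshards_h = min(batch * height * width, max_grid[0])
    while nshards_h > 1 and (batch * height) % nshards_h != 0:  # divisibility check as in common.py
        nshards_h -= 1
    nshards_w = min(channels, max_grid[1])
    # nshards_w must divide the channel count and keep the shard width a multiple of 32 bytes (2 B per element).
    while nshards_w > 1 and not (channels % nshards_w == 0 and math.ceil(channels * 2 / nshards_w) % 32 == 0):
        nshards_w -= 1
    return nshards_h, nshards_w

# Illustrative 2 x 640 x 32 x 32 activation (one of the shapes exercised in test_upsample.py):
print(pick_nshards(2, 32, 32, 640, max_grid=(8, 5)))  # old grid_sizes[32 * 32] ceiling -> (8, 5), 40 cores
print(pick_nshards(2, 32, 32, 640, max_grid=(8, 8)))  # new fixed (8, 8) ceiling -> (8, 8), 64 cores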
74 changes: 74 additions & 0 deletions models/demos/yolov4/ttnn/common.py
@@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import math
from typing import Tuple
import ttnn


@@ -121,3 +123,75 @@ def __call__(self, device, input_tensor):
return_weights_and_bias=True,
)
return output_tensor


class Upsample:
def __init__(self, input_params, scale_factor, mode="nearest") -> None:
self.batch_size = input_params[0]
self.input_height = input_params[1]
self.input_width = input_params[2]
self.input_channels = input_params[3]
self.scale_h = scale_factor[0]
self.scale_w = scale_factor[1]
self.mode = mode

# helper functions for upsample for block sharded inputs
def determine_num_cores_for_upsample(
self, batch_size: int, height: int, width: int, num_channels: int, max_grid_size: Tuple[int, int]
) -> Tuple[int, int]:
max_nshards_h = min(
batch_size * height * width, max_grid_size[0]
) ## height along NHW (N: batch size, H: height, W: width)
max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C (number of channels)
## find nshards_h along NHW
nshards_h = max_nshards_h
while nshards_h > 0:
if batch_size * height % nshards_h == 0:
break
nshards_h -= 1
## find nshards_w along C
nshards_w = max_nshards_w
while nshards_w > 0:
## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B
if num_channels % nshards_w == 0 and math.ceil(num_channels * 2 / nshards_w) % 32 == 0:
break
nshards_w -= 1
if nshards_w == 0 or nshards_h == 0:
raise ValueError(f"nshards_h or nshards_w is 0: nshards_h={nshards_h}, nshards_w={nshards_w}")
return [nshards_h, nshards_w]

def get_core_grid_from_num_cores_for_upsample(self, num_cores: Tuple[int, int], max_grid_size: Tuple[int, int]) -> ttnn.CoreRangeSet: # type: ignore
ncores_h, ncores_w = num_cores
assert ncores_h <= max_grid_size[0]
assert ncores_w <= max_grid_size[1]
return ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(ncores_w - 1, ncores_h - 1),
)
}
)

def __call__(self, device, input_tensor):
device_grid = device.compute_with_storage_grid_size()
max_grid_size = [device_grid.y, device_grid.x]
num_cores = self.determine_num_cores_for_upsample(
self.batch_size, self.input_height, self.input_width, self.input_channels, max_grid_size
)
shard_grid = self.get_core_grid_from_num_cores_for_upsample(num_cores, max_grid_size)
shard_height = math.ceil(self.input_height * self.input_width / num_cores[0])
shard_width = math.ceil(self.input_channels / num_cores[1])
shard_spec = ttnn.ShardSpec(shard_grid, (shard_height, shard_width), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config)
out_shard_spec = ttnn.ShardSpec(
shard_grid,
(shard_height * self.scale_h * self.scale_w, shard_width),
ttnn.ShardOrientation.ROW_MAJOR,
False,
)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, out_shard_spec
)
return ttnn.upsample(output_tensor, (self.scale_h, self.scale_w), memory_config=out_sharded_mem_config)
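A quick sanity check of determine_num_cores_for_upsample (my own arithmetic, not part of the diff; it assumes the Upsample class and the math import above, plus an 8 x 8 worker grid): for the first neck upsample below, the search lands on exactly the configuration that the hand-written block removed from neck.py used to hard-code.

upsample1 = Upsample([1, 10, 10, 256], [2, 2])  # first neck upsample, input given as (N, H, W, C)

num_cores = upsample1.determine_num_cores_for_upsample(1, 10, 10, 256, (8, 8))
assert num_cores == [5, 8]  # 5 shards along N*H*W, 8 along channels -> 40 cores on CoreRange((0, 0), (7, 4))

shard_height = math.ceil(10 * 10 / num_cores[0])  # 20 sticks per input shard
shard_width = math.ceil(256 / num_cores[1])  # 32 channels per input shard
# With scale_factor (2, 2) the output shard is (20 * 2 * 2, 32) = (80, 32), matching the
# ShardSpec((20, 32)) / (80, 32) pair that this PR deletes from neck.py.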
47 changes: 11 additions & 36 deletions models/demos/yolov4/ttnn/neck.py
@@ -4,7 +4,7 @@

import torch
import ttnn
from models.demos.yolov4.ttnn.common import Conv
from models.demos.yolov4.ttnn.common import Conv, Upsample
from tt_lib.fallback_ops import fallback_ops


@@ -67,6 +67,10 @@ def __init__(self, model) -> None:
width_sharding=True,
deallocate=False,
)
self.upsample1 = Upsample(
[1, 10, 10, 256],
[2, 2],
)
self.conv7_2 = Conv(
torch_model,
"neek.conv8",
@@ -126,6 +130,10 @@ def __init__(self, model) -> None:
enable_split_reader=True,
enable_act_double_buffer=True,
)
self.upsample2 = Upsample(
[1, 20, 20, 128],
[2, 2],
)
self.conv9_2 = Conv(
torch_model,
"neek.conv15",
@@ -244,25 +252,8 @@ def __call__(self, device, input_tensor):
),
memory_config=ttnn.L1_MEMORY_CONFIG,
)

output_tensor = ttnn.reshape(output_tensor, (1, 10, 10, 256))
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 4),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (20, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config)
shard_spec = ttnn.ShardSpec(shard_grid, (80, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)

output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2), memory_config=out_sharded_mem_config)
output_tensor_upsample_1 = self.upsample1(device, output_tensor)
output_tensor_upsample_1 = ttnn.sharded_to_interleaved(output_tensor_upsample_1, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_1 = ttnn.reshape(output_tensor_upsample_1, (1, 1, 400, 256))
output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT)
@@ -320,23 +311,7 @@ def __call__(self, device, input_tensor):
)

output_tensor = ttnn.reshape(output_tensor, (1, 20, 20, 128))
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(7, 4),
),
}
)
shard_spec = ttnn.ShardSpec(shard_grid, (80, 16), ttnn.ShardOrientation.ROW_MAJOR, False)
in_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=in_sharded_mem_config)
shard_spec = ttnn.ShardSpec(shard_grid, (80 * 4, 16), ttnn.ShardOrientation.ROW_MAJOR, False)
out_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)

output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2), memory_config=out_sharded_mem_config)
output_tensor_upsample_2 = self.upsample2(device, output_tensor)
output_tensor_upsample_2 = ttnn.sharded_to_interleaved(output_tensor_upsample_2, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_2 = ttnn.reshape(output_tensor_upsample_2, (1, 1, 1600, 128))
output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT)
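With the hand-written sharding blocks above removed, self.upsample1 and self.upsample2 own the shard setup and the surrounding neck code only reshapes. A short shape trace of the two calls (my own sketch, not code from the diff):

def upsampled_shape(n, h, w, c, scale_h, scale_w):
    # Nearest-neighbour upsample keeps N and C and scales only the spatial dims.
    return (n, h * scale_h, w * scale_w, c)

print(upsampled_shape(1, 10, 10, 256, 2, 2))  # (1, 20, 20, 256), flattened afterwards to (1, 1, 400, 256)
print(upsampled_shape(1, 20, 20, 128, 2, 2))  # (1, 40, 40, 128), flattened afterwards to (1, 1, 1600, 128)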
@@ -46,11 +46,11 @@ def create_unet_model_parameters(
for key in parameters.keys():
parameters[key].module = getattr(model, key)

parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 16 * 32}
parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
parameters.c1["use_split_reader"] = True
parameters.c1["use_activation_double_buffer"] = True
parameters.c1["input_channels_alignment"] = 16
parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 16 * 32}
parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
parameters.c1_2["use_split_reader"] = True
parameters.c1_2["use_activation_double_buffer"] = True
parameters.c1_2["input_channels_alignment"] = 16
@@ -136,7 +136,6 @@ def create_unet_model_parameters(
parameters.c8_3["use_split_reader"] = True
parameters.c8_3["input_channels_alignment"] = 16

parameters.output_layer["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 16 * 32}
parameters.output_layer["use_activation_double_buffer"] = True
parameters.output_layer["use_split_reader"] = True
parameters.output_layer["input_channels_alignment"] = 16
29 changes: 21 additions & 8 deletions tests/ttnn/unit_tests/operations/test_upsample.py
@@ -10,7 +10,7 @@
import torch
import torch.nn as nn
import ttnn
from models.utility_functions import skip_for_grayskull, skip_for_blackhole
from models.utility_functions import skip_for_grayskull, skip_for_blackhole, is_grayskull
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


@@ -109,12 +109,25 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
[1, 64, 132, 10],
[1, 32, 8, 8],
[2, 640, 32, 32],
# some random shapes
[1, 32, 5, 4],
[3, 32, 4, 4],
[5, 64, 5, 5],
[1, 128, 5, 8],
[1, 32, 5, 4],
[1, 64, 128, 17],
[1, 64, 132, 19],
],
)
@pytest.mark.parametrize("scale_h", [2])
@pytest.mark.parametrize("scale_w", [2])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize("scale_h", [2, 3])
@pytest.mark.parametrize("scale_w", [2, 3])
@pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK])
def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR])
def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy, shard_orientation):
if is_grayskull() and (scale_h > 2 or scale_w > 2):
pytest.skip("Skipping test because it won't fit in L1!")

## input shape is N C H W
batch_size, num_channels, height, width = input_shape
torch.manual_seed(0)
@@ -136,15 +149,15 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
max_grid_size = (device_grid.y, device_grid.x)
if shard_strategy == ttnn.ShardStrategy.HEIGHT:
## nsticks per shard should be divisible by in_w
max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1])
max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
nshards = max_nshards
while nshards > 0:
if batch_size * height % nshards == 0:
if batch_size * height * width % nshards == 0:
break
nshards -= 1
ncores = nshards
elif shard_strategy == ttnn.ShardStrategy.BLOCK:
max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW
max_nshards_h = min(batch_size * height * width, max_grid_size[0]) ## height along NHW
max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C
## find nshards_h along NHW
nshards_h = max_nshards_h
@@ -177,7 +190,6 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
# )

shard_grid = get_shard_grid_from_num_cores(device, ncores)
shard_orientation = ttnn.ShardOrientation.ROW_MAJOR

if shard_strategy == ttnn.ShardStrategy.BLOCK:
tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
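One consequence of switching the test's shard search from batch_size * height to batch_size * height * width: small spatial shapes can now fan out to far more cores under HEIGHT sharding. A hedged example (my arithmetic, assuming an 8 x 8 grid, i.e. 64 cores):

def height_shard_ncores(batch, height, width, max_cores=64):
    # HEIGHT sharding: use as many cores as evenly divide the N*H*W sticks, capped by the grid size.
    nshards = min(batch * height * width, max_cores)
    while nshards > 1 and (batch * height * width) % nshards != 0:
        nshards -= 1
    return nshards

# One of the newly added shapes, [3, 32, 4, 4] in (N, C, H, W) order:
print(height_shard_ncores(batch=3, height=4, width=4))  # 48 cores
# The previous N*H-based bound for the same shape topped out at min(3 * 4, 64) = 12 cores.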
@@ -351,6 +363,7 @@ def test_bilinear_multi_core(

## compare the results
torch_result = torch_result.permute(0, 2, 3, 1)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999)
allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1)
logger.info(pcc_msg)
@@ -7,59 +7,48 @@

void kernel_main() {
uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
uint32_t in_nsticks_per_core = get_arg_val<uint32_t>(1);
uint32_t scale_h = get_arg_val<uint32_t>(2);
uint32_t scale_w = get_arg_val<uint32_t>(3);
uint32_t in_w = get_arg_val<uint32_t>(4);
uint32_t out_w = get_arg_val<uint32_t>(5);

constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
constexpr uint32_t is_reader = get_compile_time_arg_val(2);

uint32_t in_image_row_nbytes = in_w * stick_nbytes;
uint32_t out_image_row_nbytes = out_w * stick_nbytes;
uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes;
uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes;

cb_reserve_back(out_cb_id, out_w);

// assuming shard begins with a new row. TODO: generalize?
for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) {
uint32_t l1_write_addr_image_row_start = l1_write_addr;
for (uint32_t i = 0; i < in_w; ++i) {
constexpr uint32_t config_cb_id = get_compile_time_arg_val(3);

uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
uint32_t out_nsticks_per_core = reader_nsticks_per_core * scale_h * scale_w;
uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
uint32_t l1_read_addr = get_read_ptr(in_cb_id);
uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * scale_w * stick_nbytes;

uint32_t config_l1_addr = get_read_ptr(config_cb_id);
volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(config_l1_addr);

uint32_t reader_idx = 0;
if constexpr (!is_reader) {
/* For each input stick there are 2 entries in config cb {{core_coords.x, core_coords.y}, stick_offset(in
* input_cb)} so multiply input image_row_begin with (2 * scale_h) */
reader_idx = (2 * scale_h) * image_row_begin;
}
cb_reserve_back(out_cb_id, out_nsticks_per_core);

for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
for (uint32_t sh = 0; sh < scale_h; sh++) {
uint16_t cores = config_data[reader_idx++];
uint16_t corey = cores & 0xFF;
uint16_t corex = cores >> 8;
uint16_t offset = config_data[reader_idx++];
uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes);
// replicate stick scale_w times.
for (uint32_t sw = 0; sw < scale_w; ++sw) {
// replicate stick scale_w times.
if constexpr (is_reader) {
uint64_t src_noc_addr = get_noc_addr(l1_read_addr);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
} else {
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
}
for (uint32_t sw = 0; sw < scale_w; sw++) {
noc_async_read(src_remote_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;
}
l1_read_addr += stick_nbytes;
}

// Duplicate the whole image row in one shot
if constexpr (is_reader) {
uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start);
noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes);
} else {
uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes);
}
l1_write_addr += out_image_row_nbytes;
}

cb_push_back(out_cb_id, out_w);

noc_async_write_barrier();
noc_async_read_barrier();
cb_push_back(out_cb_id, out_nsticks_per_core);
}
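The rewritten kernel above no longer copies rows out of its own local shard; each output stick is fetched with a remote NOC read, driven by a config circular buffer whose records are pairs of uint16 words: the packed (x, y) coordinates of the core that owns the source stick, and that stick's offset within the owner's input shard. A small Python sketch of the packing implied by the kernel's corey = cores & 0xFF; corex = cores >> 8 lines (illustrative only, not the host-side code from this PR):

def pack_config_entry(core_x, core_y, stick_offset):
    # One config record: packed core coordinates, then the source stick offset
    # (in sticks, relative to the owning core's input circular buffer).
    return ((core_x << 8) | core_y, stick_offset)

def unpack_config_entry(cores, offset):
    core_y = cores & 0xFF  # same bit layout the kernel decodes
    core_x = cores >> 8
    return core_x, core_y, offset

entry = pack_config_entry(core_x=3, core_y=1, stick_offset=17)
print(unpack_config_entry(*entry))  # (3, 1, 17)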