Skip to content

Commit

Permalink
#0: PR cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wransom-TT committed Oct 31, 2024
1 parent af60c98 commit f625a25
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 63 deletions.
8 changes: 1 addition & 7 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,10 @@ def run_max_pool(
pytest.skip("Block sharding requires large enough channels to shard (at least 16 per core)")

torch.manual_seed(0)
torch.set_printoptions(precision=3, sci_mode=True, linewidth=500, threshold=10000, edgeitems=32)
torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)

## construct the tensor in NCHW shape
act = torch.randn(act_shape, dtype=torch.bfloat16)
# act = torch.empty(act_shape, dtype=torch.bfloat16)
# for n in range(act_shape[0]):
# for c in range(act_shape[1]):
# for h in range(act_shape[2]):
# for w in range(act_shape[3]):
# act[n, c, h, w] = h * in_w + w
# act = torch.zeros(act_shape, dtype=torch.bfloat16)
# act = torch.ones(act_shape, dtype=torch.bfloat16)
# act = torch.arange(0, volume(act_shape), dtype=torch.bfloat16).reshape(act_shape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,30 @@
#if DEBUG_PRINT == 1
#include "debug/dprint.h"
// #include "debug_macros.h"
#define dump_unpack(a) \
do { DPRINT_UNPACK(DPRINT << "UP: "<< #a " = " << a << ENDL()); } while(false)
#define dump_pack(a) \
do { DPRINT_PACK(DPRINT << "P: "<< #a " = " << a << ENDL()); } while(false)
#define dump_math(a) \
do { DPRINT_MATH(DPRINT << "M: "<< #a " = " << a << ENDL()); } while(false)

// SliceRange srt = SliceRange{.h0 = 0, .h1 = 32, .hs = 8, .w0 = 0, .w1 = 32, .ws = 4};
// SliceRange srr = SliceRange{.h0 = 0, .h1 = 1, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange srr1 = SliceRange{.h0 = 1, .h1 = 2, .hs = 8, .w0 = 0, .w1 = 32, .ws = 1};
// SliceRange src = SliceRange{.h0 = 0, .h1 = 32, .hs = 1, .w0 = 0, .w1 = 1, .ws = 1};

// inline void print_tile_rows(uint32_t cb_id, uint32_t rows = 32, uint32_t tile_id = 0, bool untilize = false) {
// // UNPACK(( DPRINT << "======" << ENDL() ));
// for (uint16_t r = 0; r < rows; ++r) {
// SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1};
// // UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() ));
// UNPACK((DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize)));
// }
// // UNPACK(( DPRINT << "++++++" << ENDL() ));
// }
inline void print_tile_rows(uint32_t cb_id, uint32_t rows = 32, uint32_t tile_id = 0, bool untilize = false) {
// UNPACK(( DPRINT << "======" << ENDL() ));
for (uint16_t r = 0; r < rows; ++r) {
SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1};
// UNPACK(( DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL() ));
UNPACK((DPRINT << (uint)r << " :: " << TileSlice(cb_id, tile_id, sr, true, untilize)));
}
// UNPACK(( DPRINT << "++++++" << ENDL() ));
}

// inline void print_full_tile(uint32_t cb_id, uint32_t tile_id = 0, bool untilize = false) {
// UNPACK((DPRINT << "======" << ENDL()));
// for (uint16_t r = 0; r < 32; ++r) {
// SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1};
// UNPACK((DPRINT << (uint)r << " : " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL()));
// }
// UNPACK((DPRINT << "++++++" << ENDL()));
// }
inline void print_full_tile(uint32_t cb_id, uint32_t tile_id = 0, bool untilize = false) {
UNPACK((DPRINT << "======" << ENDL()));
for (uint16_t r = 0; r < 32; ++r) {
SliceRange sr = SliceRange{.h0 = r, .h1 = (uint16_t)(r + 1), .hs = 1, .w0 = 0, .w1 = 32, .ws = 1};
UNPACK((DPRINT << (uint)r << " : " << TileSlice(cb_id, tile_id, sr, true, untilize) << ENDL()));
}
UNPACK((DPRINT << "++++++" << ENDL()));
}

// inline void print_cb_details(uint32_t cb_id) {
// DPRINT << "cb_id " << cb_id << ": { "
Expand Down Expand Up @@ -93,13 +88,13 @@ void MAIN {
constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
constexpr uint32_t out_h = get_compile_time_arg_val(4);
constexpr uint32_t out_w = get_compile_time_arg_val(5);
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);

constexpr uint32_t split_reader = get_compile_time_arg_val(12);

constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(13);
constexpr uint32_t in_c = get_compile_time_arg_val(14);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(15);
constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(16);

constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
Expand All @@ -108,11 +103,10 @@ void MAIN {
constexpr uint32_t interm_reduction_cb_id = tt::CB::c_intermed1;

constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16;

constexpr bool is_partial_tile = in_c < 32;
static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
constexpr uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : MAX_ROWS_FOR_REDUCTION < 32 ? 2 : 4;
constexpr uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : max_rows_for_reduction < 32 ? 2 : 4;
constexpr uint32_t num_faces_in_output_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;

Expand All @@ -123,9 +117,9 @@ void MAIN {
num_output_tiles,
interm_reduction_cb_id,
num_faces_in_input_tile,
MAX_ROWS_FOR_REDUCTION);
max_rows_for_reduction);

uint32_t interm_reduction_chunks = window_size_hw / MAX_ROWS_FOR_REDUCTION;
uint32_t interm_reduction_chunks = window_size_hw / max_rows_for_reduction;
cb_wait_front(in_scalar_cb_id, 1);
//cb_wait_front(interm_reduction_cb_id, 1);
cb_reserve_back(out_cb_id, 1);
Expand All @@ -144,7 +138,7 @@ void MAIN {
in_cb_id,
in_scalar_cb_id,
i,
MAX_ROWS_FOR_REDUCTION);
max_rows_for_reduction);
tile_regs_wait();
tile_regs_commit();
pack_untilize_dst<num_output_tiles>(
Expand All @@ -159,7 +153,7 @@ void MAIN {
pack_untilize_uninit(interm_reduction_cb_id);
cb_wait_front(interm_reduction_cb_id, 1);

pack_untilize_dst_init_short<out_ntiles_c / in_nblocks_c>(
pack_untilize_dst_init_short<num_output_tiles>(
out_cb_id, num_out_rows, num_faces_in_output_tile);

tile_regs_acquire();
Expand All @@ -169,7 +163,7 @@ void MAIN {
num_output_tiles,
0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
MAX_ROWS_FOR_REDUCTION);
max_rows_for_reduction);
for (uint32_t c_i = 0; c_i < num_output_tiles; ++c_i) {
reduce_tile_math(c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ void kernel_main() {
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);
constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);
constexpr uint32_t in_cb_sz = get_compile_time_arg_val(14);
constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(15);

constexpr uint32_t TILE_SIZE = 32 * 32;
constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16;
constexpr uint32_t MAX_ELE_PER_REDUCTION = 512;
constexpr uint32_t ROW_HW = 64;

Expand Down Expand Up @@ -109,15 +109,15 @@ void kernel_main() {
}
uint32_t counter = reader_id;
uint32_t total_elems_to_reduce = window_h * window_w;
uint32_t remaining_elems = total_elems_to_reduce % MAX_ROWS_FOR_REDUCTION;
uint32_t remaining_elems = total_elems_to_reduce % max_rows_for_reduction;
while (counter < reader_nindices) {
for (uint32_t c_i = 0; c_i < in_nblocks_c; c_i++) {
uint16_t top_left_local_index = reader_indices_ptr[counter];
uint32_t processed_rows = 0;
cb_reserve_back(in_cb_id, npages_to_reserve);
uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id);
uint32_t out_l1_write_addr = out_l1_write_addr_base;
if ((total_elems_to_reduce - processed_rows) < MAX_ROWS_FOR_REDUCTION)
if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
fill_with_val(out_l1_write_addr, in_cb_sz, minus_inf);
for (uint32_t h = 0; h < window_h; ++h) {
for (uint32_t w = 0; w < window_w; w++) {
Expand All @@ -127,14 +127,14 @@ void kernel_main() {
noc_async_read_one_packet(get_noc_addr(read_offset), out_l1_write_addr, read_bytes);
out_l1_write_addr += read_bytes;
processed_rows++;
if ((processed_rows % MAX_ROWS_FOR_REDUCTION) == 0) {
if ((processed_rows % max_rows_for_reduction) == 0) {
noc_async_read_barrier();
cb_push_back(in_cb_id, npages_to_reserve);
cb_reserve_back(in_cb_id, npages_to_reserve);
out_l1_write_addr_base = get_write_ptr(in_cb_id);
out_l1_write_addr = out_l1_write_addr_base;
// If next is last chunk, fill whole buffer with -inf.
if ((total_elems_to_reduce - processed_rows) < MAX_ROWS_FOR_REDUCTION)
if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
fill_with_val(out_l1_write_addr, in_cb_sz, minus_inf);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <cstring>
#include "dataflow_api.h"

#define ENABLE_DEBUG_PRINT 1
#define ENABLE_DEBUG_PRINT 0

#if ENABLE_DEBUG_PRINT == 1
#include "debug/dprint.h"
Expand Down Expand Up @@ -64,8 +64,6 @@ void kernel_main() {

constexpr uint32_t in_nblocks_c = get_compile_time_arg_val(13);

// static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");

constexpr uint32_t TILE_WIDTH = 32;

constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0;
Expand All @@ -89,12 +87,6 @@ void kernel_main() {
uint32_t reader_indices_l1_addr = get_read_ptr(in_reader_indices_cb_id);
volatile tt_l1_ptr uint16_t* reader_indices_ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(reader_indices_l1_addr);

/* if (reader_id == 0) {
DPRINT << "in_nbtes_c: " << in_nbytes_c << ENDL();
DPRINT << "in_cb_nsticks: " << in_cb_nsticks << ENDL();
print_pages(in_l1_read_base_addr, in_nbytes_c / 2, (10 + 2 * pad_w) * (10 + 2 * pad_w));
} */

uint32_t in_w_padded = in_w + 2 * pad_w;

uint32_t npages_to_reserve = 1;
Expand All @@ -104,9 +96,6 @@ void kernel_main() {
uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id);
uint32_t out_l1_write_addr = out_l1_write_addr_base;
uint16_t top_left_local_index = reader_indices_ptr[counter ++];
if (reader_id == 0) {
/*DPRINT << "top_left_local_index: " << top_left_local_index << ENDL();*/
}
uint32_t h_multiples = 0;
for (uint32_t h = 0; h < window_h; ++ h, h_multiples += in_w_padded) {
uint32_t stick_offset = top_left_local_index + h_multiples;
Expand All @@ -116,10 +105,6 @@ void kernel_main() {
}
if (split_reader) counter++; // interleave the indices
noc_async_read_barrier();
if (reader_id == 0) {
/*DPRINT << "out_l1: " << ENDL();*/
/*print_pages(out_l1_write_addr_base, in_nbytes_c / 2, window_h * window_w);*/
}
cb_push_back(in_cb_id, npages_to_reserve);
}
} // kernel_main()
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,11 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
input_shape[3] / num_shards_c,
nblocks,
split_reader, // enable split reader
0, // split reader id
0, // split reader id,
bf16_one_u32,
in_nblocks_c,
in_cb_sz};
in_cb_sz,
max_rows_for_reduction};

std::vector<uint32_t> reader1_ct_args = {
out_nhw_per_core,
Expand All @@ -312,7 +313,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
1, // split reader id
bf16_one_u32,
in_nblocks_c,
in_cb_sz};
in_cb_sz,
max_rows_for_reduction};

std::string reader_kernel_fname;
if (is_large_kernel) {
Expand Down Expand Up @@ -354,7 +356,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
split_reader, // enable split reader
out_nhw_per_core / nblocks, // loop count with blocks
input_shape[3] / num_shards_c,
in_nblocks_c};
in_nblocks_c,
max_rows_for_reduction};

auto reduce_op = tt::tt_metal::ReduceOpMath::MAX;
auto reduce_dim = tt::tt_metal::ReduceOpDim::H;
Expand Down

0 comments on commit f625a25

Please sign in to comment.