Skip to content

Commit

Permalink
Rebase and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aagarwalTT committed Jan 21, 2025
1 parent 6eda798 commit fb191aa
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,40 @@ constexpr uint32_t data_buffer_start_addr = get_compile_time_arg_val(3);
constexpr uint32_t data_buffer_size_words = get_compile_time_arg_val(4);

constexpr uint32_t routing_table_start_addr = get_compile_time_arg_val(5);
constexpr uint32_t gk_interface_addr_l = get_compile_time_arg_val(6);
constexpr uint32_t gk_interface_addr_h = get_compile_time_arg_val(7);

constexpr uint32_t test_results_addr_arg = get_compile_time_arg_val(8);
constexpr uint32_t test_results_size_bytes = get_compile_time_arg_val(9);
constexpr uint32_t test_results_addr_arg = get_compile_time_arg_val(6);
constexpr uint32_t test_results_size_bytes = get_compile_time_arg_val(7);

tt_l1_ptr uint32_t* const test_results = reinterpret_cast<tt_l1_ptr uint32_t*>(test_results_addr_arg);

constexpr uint32_t prng_seed = get_compile_time_arg_val(10);
constexpr uint32_t prng_seed = get_compile_time_arg_val(8);

constexpr uint32_t total_data_kb = get_compile_time_arg_val(11);
constexpr uint32_t total_data_kb = get_compile_time_arg_val(9);
constexpr uint64_t total_data_words = ((uint64_t)total_data_kb) * 1024 / PACKET_WORD_SIZE_BYTES;

constexpr uint32_t max_packet_size_words = get_compile_time_arg_val(12);
constexpr uint32_t max_packet_size_words = get_compile_time_arg_val(10);

static_assert(max_packet_size_words > 3, "max_packet_size_words must be greater than 3");

constexpr uint32_t timeout_cycles = get_compile_time_arg_val(13);
constexpr uint32_t timeout_cycles = get_compile_time_arg_val(11);

constexpr bool skip_pkt_content_gen = get_compile_time_arg_val(14);
constexpr bool skip_pkt_content_gen = get_compile_time_arg_val(12);
constexpr pkt_dest_size_choices_t pkt_dest_size_choice =
static_cast<pkt_dest_size_choices_t>(get_compile_time_arg_val(15));
static_cast<pkt_dest_size_choices_t>(get_compile_time_arg_val(13));

constexpr uint32_t data_sent_per_iter_low = get_compile_time_arg_val(16);
constexpr uint32_t data_sent_per_iter_high = get_compile_time_arg_val(17);
constexpr uint32_t test_command = get_compile_time_arg_val(18);
constexpr uint32_t data_sent_per_iter_low = get_compile_time_arg_val(14);
constexpr uint32_t data_sent_per_iter_high = get_compile_time_arg_val(15);
constexpr uint32_t test_command = get_compile_time_arg_val(16);

uint32_t base_target_address = get_compile_time_arg_val(19);
uint32_t base_target_address = get_compile_time_arg_val(17);

// atomic increment for the ATOMIC_INC command
constexpr uint32_t atomic_increment = get_compile_time_arg_val(20);
constexpr uint32_t atomic_increment = get_compile_time_arg_val(18);
// constexpr uint32_t dest_device = get_compile_time_arg_val(21);
uint32_t dest_device;

constexpr uint32_t signal_address = get_compile_time_arg_val(21);
constexpr uint32_t client_interface_addr = get_compile_time_arg_val(22);
constexpr uint32_t signal_address = get_compile_time_arg_val(19);
constexpr uint32_t client_interface_addr = get_compile_time_arg_val(20);

uint32_t max_packet_size_mask;

Expand All @@ -77,6 +75,9 @@ uint32_t target_address;
uint32_t noc_offset;
uint32_t rx_addr_hi;

uint32_t gk_interface_addr_l;
uint32_t gk_interface_addr_h;

// generates packets with random size and payload on the input side
inline bool test_buffer_handler_async_wr() {
if (input_queue_state.all_packets_done()) {
Expand Down Expand Up @@ -321,24 +322,27 @@ bool test_buffer_handler() {
return test_buffer_handler_async_wr();
} else if constexpr (test_command == ATOMIC_INC) {
return test_buffer_handler_atomic_inc();
} else if constexpr (test_command == SOCKET_OPEN) {
} else if constexpr (test_command == ASYNC_WR_RESP) {
return test_buffer_handler_fvcc();
}
}

void kernel_main() {
tt_fabric_init();

uint32_t rt_args_idx = 0;
// TODO: refactor
src_endpoint_id = get_arg_val<uint32_t>(0);
noc_offset = get_arg_val<uint32_t>(1);
uint32_t router_x = get_arg_val<uint32_t>(2);
uint32_t router_y = get_arg_val<uint32_t>(3);
dest_device = get_arg_val<uint32_t>(4);
uint32_t rx_buf_size = get_arg_val<uint32_t>(5);
src_endpoint_id = get_arg_val<uint32_t>(rt_args_idx++);
noc_offset = get_arg_val<uint32_t>(rt_args_idx++);
uint32_t router_x = get_arg_val<uint32_t>(rt_args_idx++);
uint32_t router_y = get_arg_val<uint32_t>(rt_args_idx++);
dest_device = get_arg_val<uint32_t>(rt_args_idx++);
uint32_t rx_buf_size = get_arg_val<uint32_t>(rt_args_idx++);
gk_interface_addr_l = get_arg_val<uint32_t>(rt_args_idx++);
gk_interface_addr_h = get_arg_val<uint32_t>(rt_args_idx++);

if (ASYNC_WR == test_command) {
base_target_address = get_arg_val<uint32_t>(6);
base_target_address = get_arg_val<uint32_t>(rt_args_idx++);
}
target_address = base_target_address;
rx_addr_hi = base_target_address + rx_buf_size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ int main(int argc, char** argv) {
uint32_t routing_table_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED);
uint32_t gk_interface_addr = routing_table_addr + sizeof(tt::tt_fabric::fabric_router_l1_config_t) * 4;
uint32_t client_interface_addr = routing_table_addr + sizeof(tt::tt_fabric::fabric_router_l1_config_t) * 4;
;
uint32_t socket_info_addr = gk_interface_addr + sizeof(gatekeeper_info_t);
log_info(LogTest, "GK Routing Table Addr = 0x{:08X}", routing_table_addr);
log_info(LogTest, "GK Info Addr = 0x{:08X}", gk_interface_addr);
Expand Down Expand Up @@ -338,28 +337,33 @@ int main(int argc, char** argv) {
log_info(LogTest, "Device {} router_mask = 0x{:04X}", device.first, router_mask);
uint32_t sem_count = device_router_cores.size();
device_router_map[device.first] = device_router_phys_cores;
std::vector<uint32_t> runtime_args = {
sem_count, // 0: number of active fabric routers
router_mask, // 1: active fabric router mask
};

gk_phys_core = (device.second->worker_core_from_logical_core(gk_core));
uint32_t gk_noc_offset = tt_metal::hal.noc_xy_encoding(gk_phys_core.x, gk_phys_core.y);

std::vector<uint32_t> router_compile_args = {
(tunneler_queue_size_bytes >> 4), // 0: rx_queue_size_words
tunneler_test_results_addr, // 1: test_results_addr
tunneler_test_results_size, // 2: test_results_size
0, // 3: timeout_cycles
};

std::vector<uint32_t> router_runtime_args = {
sem_count, // 0: number of active fabric routers
router_mask, // 1: active fabric router mask
gk_interface_addr, // 2: gk_message_addr_l
gk_noc_offset, // 3: gk_message_addr_h
};

for (auto logical_core : device_router_cores) {
std::vector<uint32_t> router_compile_args = {
(tunneler_queue_size_bytes >> 4), // 0: rx_queue_size_words
gk_interface_addr, // 1: gk_message_addr_l
(gk_phys_core.y << 10) | (gk_phys_core.x << 4), // 2: gk_message_addr_h
tunneler_test_results_addr, // 3: test_results_addr
tunneler_test_results_size, // 4: test_results_size
0, // 5: timeout_cycles
};
auto router_kernel = tt_metal::CreateKernel(
program_map[device.first],
"tt_fabric/impl/kernels/tt_fabric_router.cpp",
logical_core,
tt_metal::EthernetConfig{
.noc = tt_metal::NOC::NOC_0, .compile_args = router_compile_args, .defines = defines});

tt_metal::SetRuntimeArgs(program_map[device.first], router_kernel, logical_core, runtime_args);
tt_metal::SetRuntimeArgs(program_map[device.first], router_kernel, logical_core, router_runtime_args);

log_debug(
LogTest,
Expand All @@ -378,6 +382,11 @@ int main(int argc, char** argv) {
0, // 5: timeout_cycles
};

std::vector<uint32_t> gk_runtime_args = {
sem_count, // 0: number of active fabric routers
router_mask, // 1: active fabric router mask
};

auto kernel = tt_metal::CreateKernel(
program_map[device.first],
"tt_fabric/impl/kernels/tt_fabric_gatekeeper.cpp",
Expand All @@ -388,7 +397,7 @@ int main(int argc, char** argv) {
.compile_args = gk_compile_args,
.defines = defines});

tt_metal::SetRuntimeArgs(program_map[device.first], kernel, gk_core, runtime_args);
tt_metal::SetRuntimeArgs(program_map[device.first], kernel, gk_core, gk_runtime_args);
}

if (check_txrx_timeout) {
Expand All @@ -407,33 +416,36 @@ int main(int argc, char** argv) {
tx_queue_start_addr, // 3: queue_start_addr_words
(tx_queue_size_bytes >> 4), // 4: queue_size_words
routing_table_start_addr, // 5: routeing table
gk_interface_addr, // 6: gk_message_addr_l
(tx_gk_phys_core.y << 10) | (tx_gk_phys_core.x << 4), // 7: gk_message_addr_h
test_results_addr, // 8: test_results_addr
test_results_size, // 9: test_results_size
prng_seed, // 10: prng_seed
data_kb_per_tx, // 11: total_data_kb
max_packet_size_words, // 12: max_packet_size_words
timeout_mcycles * 1000 * 1000 * 4, // 13: timeout_cycles
tx_skip_pkt_content_gen, // 14: skip_pkt_content_gen
tx_pkt_dest_size_choice, // 15: pkt_dest_size_choice
tx_data_sent_per_iter_low, // 16: data_sent_per_iter_low
tx_data_sent_per_iter_high, // 17: data_sent_per_iter_high
fabric_command, // 18: fabric command
target_address,
atomic_increment,
tx_signal_address,
test_results_addr, // 6: test_results_addr
test_results_size, // 7: test_results_size
prng_seed, // 8: prng_seed
data_kb_per_tx, // 9: total_data_kb
max_packet_size_words, // 10: max_packet_size_words
timeout_mcycles * 1000 * 1000 * 4, // 11: timeout_cycles
tx_skip_pkt_content_gen, // 12: skip_pkt_content_gen
tx_pkt_dest_size_choice, // 13: pkt_dest_size_choice
tx_data_sent_per_iter_low, // 14: data_sent_per_iter_low
tx_data_sent_per_iter_high, // 15: data_sent_per_iter_high
fabric_command, // 16: fabric command
target_address, // 17:
atomic_increment, // 18:
tx_signal_address, // 19:
client_interface_addr,

};

// setup runtime args
uint32_t tx_gk_noc_offset = tt_metal::hal.noc_xy_encoding(tx_gk_phys_core.x, tx_gk_phys_core.y);
std::vector<uint32_t> runtime_args = {
(device_map[test_device_id_l]->id() << 8) + src_endpoint_start_id + i, // 0: src_endpoint_id
0x410, // 1: dest_noc_offset
router_phys_core.x,
router_phys_core.y,
(dev_r_mesh_id << 16 | dev_r_chip_id)};
router_phys_core.x, // 2: router_x
router_phys_core.y, // 3: router_y
(dev_r_mesh_id << 16 | dev_r_chip_id), // 4: mesh and chip id
0xd0000, // 5: space in rx's L1
gk_interface_addr, // 6: gk_message_addr_l
tx_gk_noc_offset, // 7: gk_message_addr_h
};

if (ASYNC_WR == fabric_command) {
runtime_args.push_back(target_address);
Expand Down
Loading

0 comments on commit fb191aa

Please sign in to comment.