Skip to content

Commit

Permalink
Add gatekeeper to manage sockets.
Browse files Browse the repository at this point in the history
    Add APIs to open/connect datagram sockets.
    Add Socket sanity test.
  • Loading branch information
ubcheema authored and aagarwalTT committed Jan 21, 2025
1 parent 5c8604a commit 6eda798
Show file tree
Hide file tree
Showing 12 changed files with 2,825 additions and 158 deletions.
1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/perf_microbenchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ set(PERF_MICROBENCH_TESTS_SRCS
routing/test_vc_loopback_tunnel.cpp
routing/test_tt_fabric_sanity.cpp
routing/test_tt_fabric_multi_hop_sanity.cpp
routing/test_tt_fabric_socket_sanity.cpp
noc/test_noc_unicast_vs_multicast_to_single_core_latency.cpp
old/matmul/matmul_global_l1.cpp
old/matmul/matmul_local_l1.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

// clang-format off
#include "debug/dprint.h"
#include "dataflow_api.h"
#include "tt_fabric/hw/inc/tt_fabric.h"
#include "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen.hpp"
#include "tt_fabric/hw/inc/tt_fabric_interface.h"
#include "tt_fabric/hw/inc/tt_fabric_api.h"
// clang-format on

// ---------------------------------------------------------------------------
// Compile-time kernel arguments (index shown in brackets).
// ---------------------------------------------------------------------------

// [0] PRNG seed handed in by the host; not referenced by the code in this file.
constexpr uint32_t prng_seed = get_compile_time_arg_val(0);

// [1] Total payload size in KB, and the same quantity converted to packet words.
constexpr uint32_t total_data_kb = get_compile_time_arg_val(1);
constexpr uint64_t total_data_words = static_cast<uint64_t>(total_data_kb) * 1024 / PACKET_WORD_SIZE_BYTES;

// [2] Maximum packet size in words; anything of 3 words or fewer is rejected at compile time.
constexpr uint32_t max_packet_size_words = get_compile_time_arg_val(2);
static_assert(max_packet_size_words > 3, "max_packet_size_words must be greater than 3");

// [3] Fabric command selected for the test.
constexpr uint32_t test_command = get_compile_time_arg_val(3);

// [4] Target address (start of read / poll location).
constexpr uint32_t target_address = get_compile_time_arg_val(4);

// [5] Increment value for the ATOMIC_INC command.
constexpr uint32_t atomic_increment = get_compile_time_arg_val(5);

// [6], [7] Location and size in bytes of the results buffer reported back to the host.
constexpr uint32_t test_results_addr_arg = get_compile_time_arg_val(6);
constexpr uint32_t test_results_size_bytes = get_compile_time_arg_val(7);

// [8], [9] Low and high 32-bit halves of the 64-bit gatekeeper interface address.
constexpr uint32_t gk_interface_addr_l = get_compile_time_arg_val(8);
constexpr uint32_t gk_interface_addr_h = get_compile_time_arg_val(9);

// [10], [11] Addresses of the client interface structure and of the pull request buffer.
constexpr uint32_t client_interface_addr = get_compile_time_arg_val(10);
constexpr uint32_t client_pull_req_buf_addr = get_compile_time_arg_val(11);

// [12], [13] Socket data buffer: start address and size in packet words.
constexpr uint32_t data_buffer_start_addr = get_compile_time_arg_val(12);
constexpr uint32_t data_buffer_size_words = get_compile_time_arg_val(13);

// ---------------------------------------------------------------------------
// Kernel-wide state.
// ---------------------------------------------------------------------------

// Pull request buffer polled by kernel_main for incoming requests.
volatile tt_l1_ptr chan_req_buf* client_pull_req_buf =
    reinterpret_cast<tt_l1_ptr chan_req_buf*>(client_pull_req_buf_addr);

// Client interface block through which the gatekeeper / pull buffer addresses are published.
volatile tt_fabric_client_interface_t* client_interface =
    reinterpret_cast<volatile tt_fabric_client_interface_t*>(client_interface_addr);

// OR-ed into the pull request buffer address in kernel_main; never assigned in this file.
uint64_t xy_local_addr;

// Reader state for the receive socket.
socket_reader_state socket_reader;

// Word-addressed view of the results buffer.
tt_l1_ptr uint32_t* const test_results = reinterpret_cast<tt_l1_ptr uint32_t*>(test_results_addr_arg);

#define PAYLOAD_MASK (0xFFFF0000)

// Receive-side kernel of the fabric socket sanity test.
//
// Sequence:
//   1. Zero the results buffer, the socket data buffer, the client interface
//      block and the pull request buffer.
//   2. Publish the gatekeeper interface/message-buffer addresses and the pull
//      request buffer address through client_interface.
//   3. Open a receive datagram socket towards dest_device.
//   4. Service pull requests until at least one packet has been received and
//      0x10000 loop iterations have passed since the last completed packet.
//   5. Write the packet count and final status into test_results.
//
// Progress markers are stored in test_results[PQ_TEST_MISC_INDEX] after each
// step so the last value written shows how far the kernel got.
//
// Runtime args:
//   [0] dest_device - bits [31:16] and bits [15:0] are handed to
//       fabric_socket_open() as its two remote mesh/device arguments.
void kernel_main() {
    // This kernel does not count payload words; the word counter is reported as zero.
    uint64_t processed_packet_words = 0;
    // No payload validation is performed below, so this flag is never raised.
    // It is kept so the result reporting at the end retains its mismatch path.
    bool async_wr_check_failed = false;

    // parse runtime args
    uint32_t dest_device = get_arg_val<uint32_t>(0);

    tt_fabric_init();

    zero_l1_buf(test_results, test_results_size_bytes);
    test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_STARTED;
    test_results[PQ_TEST_MISC_INDEX] = 0xff000000;
    zero_l1_buf(
        reinterpret_cast<tt_l1_ptr uint32_t*>(data_buffer_start_addr), data_buffer_size_words * PACKET_WORD_SIZE_BYTES);
    test_results[PQ_TEST_MISC_INDEX] = 0xff000001;
    zero_l1_buf((uint32_t*)client_interface, sizeof(tt_fabric_client_interface_t));
    test_results[PQ_TEST_MISC_INDEX] = 0xff000002;
    zero_l1_buf((uint32_t*)client_pull_req_buf, sizeof(chan_req_buf));
    test_results[PQ_TEST_MISC_INDEX] = 0xff000003;

    // Assemble the 64-bit gatekeeper address from its two 32-bit compile-time halves.
    client_interface->gk_interface_addr = ((uint64_t)gk_interface_addr_h << 32) | gk_interface_addr_l;
    client_interface->gk_msg_buf_addr = client_interface->gk_interface_addr + offsetof(gatekeeper_info_t, gk_msg_buf);
    client_interface->pull_req_buf_addr = xy_local_addr | client_pull_req_buf_addr;
    test_results[PQ_TEST_MISC_INDEX] = 0xff000004;

    // make sure fabric node gatekeeper is available.
    fabric_endpoint_init();

    socket_reader.init(data_buffer_start_addr, data_buffer_size_words);
    DPRINT << "Socket open on " << dest_device << ENDL();
    test_results[PQ_TEST_MISC_INDEX] = 0xff000005;

    fabric_socket_open(
        3,                      // the network plane to use for this socket
        2,                      // Temporal epoch for which the socket is being opened
        1,                      // Socket Id to open
        SOCKET_TYPE_DGRAM,      // Unicast, Multicast, SSocket, DSocket
        SOCKET_DIRECTION_RECV,  // Send or Receive
        dest_device >> 16,      // Remote mesh/device that is the socket data sender/receiver.
        dest_device & 0xFFFF,
        0  // fabric virtual channel.
    );
    test_results[PQ_TEST_MISC_INDEX] = 0xff000006;

    uint32_t loop_count = 0;    // iterations since the last completed packet
    uint32_t packet_count = 0;  // packets fully received so far
    while (1) {
        if (!fvc_req_buf_is_empty(client_pull_req_buf) && fvc_req_valid(client_pull_req_buf)) {
            uint32_t req_index = client_pull_req_buf->rdptr.ptr & CHAN_REQ_BUF_SIZE_MASK;
            chan_request_entry_t* req = (chan_request_entry_t*)client_pull_req_buf->chan_req + req_index;
            pull_request_t* pull_req = &req->pull_request;
            if (socket_reader.packet_in_progress == 0) {
                DPRINT << "Socket Packet " << packet_count << ENDL();
            }
            if (pull_req->flags == FORWARD) {
                socket_reader.pull_socket_data(pull_req);
                test_results[PQ_TEST_MISC_INDEX] = 0xDD000001;
                // Wait for the issued reads to land before reporting the words as cleared.
                noc_async_read_barrier();
                update_pull_request_words_cleared(pull_req);
                socket_reader.pull_words_in_flight = 0;
                socket_reader.push_socket_data<false>();
            }

            if (socket_reader.packet_in_progress == 1 and socket_reader.packet_words_remaining == 0) {
                // wait for any pending socket data writes to finish.
                test_results[PQ_TEST_MISC_INDEX] = 0xDD000002;

                noc_async_write_barrier();

                test_results[PQ_TEST_MISC_INDEX] = 0xDD000003;
                // The packet is complete: advance the request buffer read pointer
                // past this entry so the next pull request can be serviced.
                req_buf_advance_rdptr((chan_req_buf*)client_pull_req_buf);
                socket_reader.packet_in_progress = 0;
                packet_count++;
                loop_count = 0;
            }
        }
        // Heartbeat marker: low 20 bits carry the current idle-iteration counter.
        test_results[PQ_TEST_MISC_INDEX] = 0xDD400000 | (loop_count & 0xfffff);

        loop_count++;
        // Stop once something has been received and the stream has gone quiet.
        if (packet_count > 0 and loop_count >= 0x10000) {
            DPRINT << "Socket Rx Finished " << packet_count << ENDL();
            break;
        }
    }

    // write out results
    set_64b_result(test_results, processed_packet_words, PQ_TEST_WORD_CNT_INDEX);
    // Report the number of packets actually received (previously a constant zero).
    set_64b_result(test_results, static_cast<uint64_t>(packet_count), TX_TEST_IDX_NPKT);

    if (async_wr_check_failed) {
        test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_DATA_MISMATCH;
    } else {
        test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_PASS;
        test_results[PQ_TEST_MISC_INDEX] = 0xff000005;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "tt_fabric/hw/inc/tt_fabric.h"
#include "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen.hpp"
#include "tt_fabric/hw/inc/tt_fabric_interface.h"
#include "tt_fabric/hw/inc/tt_fabric_api.h"
// clang-format on

uint32_t src_endpoint_id;
Expand All @@ -20,8 +21,8 @@ constexpr uint32_t data_buffer_start_addr = get_compile_time_arg_val(3);
// [4] Size of the data buffer in packet words.
constexpr uint32_t data_buffer_size_words = get_compile_time_arg_val(4);

// [5] Address of the routing table structure.
constexpr uint32_t routing_table_start_addr = get_compile_time_arg_val(5);
// constexpr uint32_t router_x = get_compile_time_arg_val(6);
// constexpr uint32_t router_y = get_compile_time_arg_val(7);
// [6], [7] Low and high 32-bit halves of the gatekeeper interface address;
// kernel_main combines them as (h << 32) | l. These indices were previously
// used by the commented-out router_x / router_y arguments above.
constexpr uint32_t gk_interface_addr_l = get_compile_time_arg_val(6);
constexpr uint32_t gk_interface_addr_h = get_compile_time_arg_val(7);

// [8], [9] Address and size in bytes of the test results buffer.
constexpr uint32_t test_results_addr_arg = get_compile_time_arg_val(8);
constexpr uint32_t test_results_size_bytes = get_compile_time_arg_val(9);
Expand Down Expand Up @@ -55,15 +56,18 @@ constexpr uint32_t atomic_increment = get_compile_time_arg_val(20);
// Destination of generated packets: the upper 16 bits are used as the mesh id
// and the lower 16 bits as the device id when packet headers are filled in.
uint32_t dest_device;

// [21] Signal address. NOTE(review): its use is outside the visible code - confirm.
constexpr uint32_t signal_address = get_compile_time_arg_val(21);
// [22] Address of the fabric client interface structure.
constexpr uint32_t client_interface_addr = get_compile_time_arg_val(22);

// Set up at runtime in kernel_main (it is reset to 0 there before being derived).
uint32_t max_packet_size_mask;

// Packet generator state, selected at compile time by pkt_dest_size_choice.
auto input_queue_state = select_input_queue<pkt_dest_size_choice>();
// Local pull request structure, located 1024 bytes below the start of the data buffer.
volatile local_pull_request_t *local_pull_request = (volatile local_pull_request_t *)(data_buffer_start_addr - 1024);
// Routing table; supplies this node's device id and mesh id for outgoing headers.
tt_l1_ptr volatile tt::tt_fabric::fabric_router_l1_config_t* routing_table =
reinterpret_cast<tt_l1_ptr tt::tt_fabric::fabric_router_l1_config_t*>(routing_table_start_addr);
// Client interface block; kernel_main zeroes it and stores the gatekeeper message buffer address in it.
volatile tt_fabric_client_interface_t* client_interface = (volatile tt_fabric_client_interface_t*)client_interface_addr;

// Producer state objects. Both are initialized in kernel_main with
// data_buffer_start_addr as their buffer start.
fvc_producer_state_t test_producer __attribute__((aligned(16)));
fvcc_inbound_state_t fvcc_test_producer __attribute__((aligned(16)));

// NOTE(review): presumably the NOC XY-encoded base address of this core; the
// assignment is not in the visible code - confirm.
uint64_t xy_local_addr;

Expand Down Expand Up @@ -243,11 +247,82 @@ inline bool test_buffer_handler_atomic_inc() {
return false;
}

inline bool test_buffer_handler_fvcc() {
if (input_queue_state.all_packets_done()) {
return true;
}

uint32_t free_words = fvcc_test_producer.get_num_msgs_free() * PACKET_HEADER_SIZE_WORDS;
if (free_words < PACKET_HEADER_SIZE_WORDS) {
return false;
}

uint32_t byte_wr_addr = fvcc_test_producer.get_local_buffer_write_addr();
uint32_t words_to_init = std::min(free_words, fvcc_test_producer.words_before_local_buffer_wrap());
uint32_t words_initialized = 0;
while (words_initialized < words_to_init) {
if (input_queue_state.all_packets_done()) {
break;
}

if (!input_queue_state.packet_active()) { // start of a new packet
input_queue_state.next_inline_packet(total_data_words);

tt_l1_ptr uint32_t* header_ptr = reinterpret_cast<tt_l1_ptr uint32_t*>(byte_wr_addr);

packet_header.routing.flags = SYNC;
packet_header.routing.dst_mesh_id = dest_device >> 16;
packet_header.routing.dst_dev_id = dest_device & 0xFFFF;
packet_header.routing.src_dev_id = routing_table->my_device_id;
packet_header.routing.src_mesh_id = routing_table->my_mesh_id;
packet_header.routing.packet_size_bytes = PACKET_HEADER_SIZE_BYTES;
packet_header.session.command = ASYNC_WR_RESP;
packet_header.session.target_offset_l = target_address;
packet_header.session.target_offset_h = noc_offset;
packet_header.packet_parameters.misc_parameters.words[1] = 0;
packet_header.packet_parameters.misc_parameters.words[2] = 0;
tt_fabric_add_header_checksum(&packet_header);
uint32_t words_left = words_to_init - words_initialized;
bool split_header = words_left < PACKET_HEADER_SIZE_WORDS;
uint32_t header_words_to_init = PACKET_HEADER_SIZE_WORDS;
if (split_header) {
header_words_to_init = words_left;
}
for (uint32_t i = 0; i < (header_words_to_init * PACKET_WORD_SIZE_BYTES / 4); i++) {
header_ptr[i] = ((uint32_t*)&packet_header)[i];
}

words_initialized += header_words_to_init;
input_queue_state.curr_packet_words_remaining -= header_words_to_init;
byte_wr_addr += header_words_to_init * PACKET_WORD_SIZE_BYTES;
} else {
tt_l1_ptr uint32_t* header_ptr = reinterpret_cast<tt_l1_ptr uint32_t*>(byte_wr_addr);
uint32_t header_words_initialized =
input_queue_state.curr_packet_size_words - input_queue_state.curr_packet_words_remaining;
uint32_t header_words_to_init = PACKET_HEADER_SIZE_WORDS - header_words_initialized;
uint32_t header_dword_index = header_words_initialized * PACKET_WORD_SIZE_BYTES / 4;
uint32_t words_left = words_to_init - words_initialized;
header_words_to_init = std::min(words_left, header_words_to_init);

for (uint32_t i = 0; i < (header_words_to_init * PACKET_WORD_SIZE_BYTES / 4); i++) {
header_ptr[i] = ((uint32_t*)&packet_header)[i + header_dword_index];
}
words_initialized += header_words_to_init;
input_queue_state.curr_packet_words_remaining -= header_words_to_init;
byte_wr_addr += header_words_to_init * PACKET_WORD_SIZE_BYTES;
}
}
fvcc_test_producer.advance_local_wrptr(words_initialized / PACKET_HEADER_SIZE_WORDS);
return false;
}

// Dispatches to the buffer handler that matches the compile-time test_command
// and returns that handler's result.
//
// For a test_command with no matching handler the function returns true
// instead of flowing off the end of a non-void function (undefined behaviour).
// NOTE(review): true is what test_buffer_handler_fvcc() returns once all
// packets are done, so an unsupported command is presumably treated by the
// caller as "nothing left to generate" - confirm against kernel_main.
bool test_buffer_handler() {
    if constexpr (test_command == ASYNC_WR) {
        return test_buffer_handler_async_wr();
    } else if constexpr (test_command == ATOMIC_INC) {
        return test_buffer_handler_atomic_inc();
    } else if constexpr (test_command == SOCKET_OPEN) {
        return test_buffer_handler_fvcc();
    } else {
        // Unsupported command: no packets are generated.
        return true;
    }
}

Expand Down Expand Up @@ -284,6 +359,9 @@ void kernel_main() {
zero_l1_buf(reinterpret_cast<tt_l1_ptr uint32_t*>(data_buffer_start_addr), data_buffer_size_words * PACKET_WORD_SIZE_BYTES);
zero_l1_buf((uint32_t*)local_pull_request, sizeof(local_pull_request_t));
zero_l1_buf((uint32_t*)&packet_header, sizeof(packet_header_t));
zero_l1_buf((uint32_t*)client_interface, sizeof(tt_fabric_client_interface_t));
client_interface->gk_msg_buf_addr =
(((uint64_t)gk_interface_addr_h << 32) | gk_interface_addr_l) + offsetof(gatekeeper_info_t, gk_msg_buf);

if constexpr (pkt_dest_size_choice == pkt_dest_size_choices_t::RANDOM) {
input_queue_state.init(src_endpoint_id, prng_seed);
Expand All @@ -294,6 +372,7 @@ void kernel_main() {
}

test_producer.init(data_buffer_start_addr, data_buffer_size_words, 0x0);
fvcc_test_producer.init(data_buffer_start_addr, 0x0, 0x0);

uint32_t temp = max_packet_size_words;
max_packet_size_mask = 0;
Expand Down Expand Up @@ -364,6 +443,15 @@ void kernel_main() {
DPRINT << "Packet Header Corrupted: packet " << packet_count
<< " Addr: " << test_producer.get_local_buffer_read_addr() << ENDL();
break;
} else if (fvcc_test_producer.get_curr_packet_valid()) {
fvcc_test_producer.fvcc_handler<FVC_MODE_ENDPOINT>();
#ifdef CHECK_TIMEOUT
progress_timestamp = get_timestamp_32b();
#endif
} else if (fvcc_test_producer.packet_corrupted) {
DPRINT << "Packet Header Corrupted: packet " << packet_count
<< " Addr: " << fvcc_test_producer.get_local_buffer_read_addr() << ENDL();
break;
} else if (all_packets_initialized) {
DPRINT << "all packets done" << ENDL();
break;
Expand Down
Loading

0 comments on commit 6eda798

Please sign in to comment.