Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TransferEngine] adjust transfer_engine_bench: 1.Introduce the gflag buffer_size for enhanced configurability. 2. Utilize uint64_t for block_size to prevent overflow. #72

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions mooncake-transfer-engine/example/transfer_engine_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ DEFINE_string(nic_priority_matrix, "",
"Path to RDMA NIC priority matrix file (Advanced)");

DEFINE_string(segment_id, "192.168.3.76", "Segment ID to access data");
DEFINE_uint64(buffer_size, 1ull << 30, "total size of data buffer");
DEFINE_int32(batch_size, 128, "Batch size");
DEFINE_int32(block_size, 4096, "Block size for each transfer request");
DEFINE_uint64(block_size, 4096, "Block size for each transfer request");
DEFINE_int32(duration, 10, "Test duration in seconds");
DEFINE_int32(threads, 4, "Task submission threads");

Expand Down Expand Up @@ -226,7 +227,6 @@ std::string loadNicPriorityMatrix() {
}

int initiator() {
const size_t ram_buffer_size = 1ull << 30;
auto engine = std::make_unique<TransferEngine>();

auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name);
Expand Down Expand Up @@ -255,16 +255,16 @@ int initiator() {
buffer_num = FLAGS_use_vram ? 1 : NR_SOCKETS;
if (FLAGS_use_vram) LOG(INFO) << "VRAM is used";
for (int i = 0; i < buffer_num; ++i) {
addr[i] = allocateMemoryPool(ram_buffer_size, i, FLAGS_use_vram);
addr[i] = allocateMemoryPool(FLAGS_buffer_size, i, FLAGS_use_vram);
std::string name_prefix = FLAGS_use_vram ? "gpu:" : "cpu:";
int rc = engine->registerLocalMemory(addr[i], ram_buffer_size,
int rc = engine->registerLocalMemory(addr[i], FLAGS_buffer_size,
name_prefix + std::to_string(i));
LOG_ASSERT(!rc);
}
#else
for (int i = 0; i < buffer_num; ++i) {
addr[i] = allocateMemoryPool(ram_buffer_size, i, false);
int rc = engine->registerLocalMemory(addr[i], ram_buffer_size,
addr[i] = allocateMemoryPool(FLAGS_buffer_size, i, false);
int rc = engine->registerLocalMemory(addr[i], FLAGS_buffer_size,
"cpu:" + std::to_string(i));
LOG_ASSERT(!rc);
}
Expand Down Expand Up @@ -299,14 +299,13 @@ int initiator() {

for (int i = 0; i < buffer_num; ++i) {
engine->unregisterLocalMemory(addr[i]);
freeMemoryPool(addr[i], ram_buffer_size);
freeMemoryPool(addr[i], FLAGS_buffer_size);
}

return 0;
}

int target() {
const size_t ram_buffer_size = 1ull << 30;
auto engine = std::make_unique<TransferEngine>();

auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name);
Expand All @@ -327,9 +326,9 @@ int target() {

void *addr[NR_SOCKETS] = {nullptr};
for (int i = 0; i < NR_SOCKETS; ++i) {
addr[i] = allocateMemoryPool(ram_buffer_size, i);
memset(addr[i], 'x', ram_buffer_size);
int rc = engine->registerLocalMemory(addr[i], ram_buffer_size,
addr[i] = allocateMemoryPool(FLAGS_buffer_size, i);
memset(addr[i], 'x', FLAGS_buffer_size);
int rc = engine->registerLocalMemory(addr[i], FLAGS_buffer_size,
"cpu:" + std::to_string(i));
LOG_ASSERT(!rc);
}
Expand All @@ -338,14 +337,25 @@ int target() {

for (int i = 0; i < NR_SOCKETS; ++i) {
engine->unregisterLocalMemory(addr[i]);
freeMemoryPool(addr[i], ram_buffer_size);
freeMemoryPool(addr[i], FLAGS_buffer_size);
}

return 0;
}

void check_total_buffer_size() {
uint64_t require_size = FLAGS_block_size * FLAGS_batch_size * FLAGS_threads;
if (FLAGS_buffer_size < require_size) {
FLAGS_buffer_size = require_size;
LOG(WARNING) << "Invalid flag: buffer size is samller than "
"require_size, adjust to "
<< require_size;
}
}

int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
check_total_buffer_size();

if (FLAGS_mode == "initiator")
return initiator();
Expand Down