Skip to content

Commit

Permalink
can work
Browse files Browse the repository at this point in the history
  • Loading branch information
XinShuoWang committed Jan 5, 2022
1 parent 5e01c9f commit 432d8d7
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 37 deletions.
30 changes: 22 additions & 8 deletions include/dijiang/Context.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
#pragma once

#include "Message.h"

#include <rdma/rdma_cma.h>

struct Context
{
struct ibv_context *ctx;
struct ibv_pd *pd;
struct ibv_cq *cq;
struct ibv_comp_channel *comp_channel;
ibv_context *ctx;
ibv_pd *pd;
ibv_cq *cq;
ibv_comp_channel *comp_channel;
};

struct ConnectionContext
struct ServerConnectionContext
{
char *buffer;
struct ibv_mr *buffer_mr;
struct Message *msg;
struct ibv_mr *msg_mr;
ibv_mr *buffer_mr;
Message *msg;
ibv_mr *msg_mr;
};

struct ClientConnectionContext
{
char *buffer;
ibv_mr *buffer_mr;

Message *msg;
ibv_mr *msg_mr;

uint64_t peer_addr;
uint32_t peer_rkey;
};
1 change: 1 addition & 0 deletions include/dijiang/Message.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <cstdlib>
#include <stdint.h>

enum MessageId
{
Expand Down
198 changes: 192 additions & 6 deletions include/dijiang/RdmaClientSocket.hpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,203 @@
#pragma once

class RdmaClientSocket {
public:
RdmaClientSocket(const char* ip, const char * port) {
#include "Debug.hpp"
#include "RdmaClientSocketHelper.h"

class RdmaClientSocket
{
public:
RdmaClientSocket(const char *ip, const char *port, int thread_num, int buffer_size, int timeout)
{
// init
speaker_ = NULL;
channel_ = NULL;
context_ = NULL;
timeout_ = timeout;
buffer_size_ = buffer_size;
thread_pool_ = new ThreadPool(thread_num);
// set callback
RegisterCallback(buffer_size);
// resolve address
addrinfo *address = NULL;
TEST_NZ(getaddrinfo(ip, port, NULL, &address));
TEST_Z(channel_ = rdma_create_event_channel());
TEST_NZ(rdma_create_id(channel_, &speaker_, NULL, RDMA_PS_TCP));
TEST_NZ(rdma_resolve_addr(speaker_, NULL, address->ai_addr, timeout));
freeaddrinfo(address);
// set connection context
speaker_->context = (ClientConnectionContext *)malloc(sizeof(ClientConnectionContext));
}

~RdmaClientSocket() {

~RdmaClientSocket()
{
free(context_);
rdma_destroy_event_channel(channel_);
delete thread_pool_;
}

void Write(const char* data, const int size) {
void Loop()
{
rdma_conn_param params;
memset(&params, 0, sizeof(rdma_conn_param));
params.initiator_depth = params.responder_resources = 1;
params.rnr_retry_count = 7; /* infinite retry */

rdma_cm_event *event = NULL;
while (rdma_get_cm_event(channel_, &event) == 0)
{
struct rdma_cm_event event_copy;
memcpy(&event_copy, event, sizeof(rdma_cm_event));
rdma_ack_cm_event(event);
switch (event_copy.event)
{
case RDMA_CM_EVENT_ADDR_RESOLVED:
printf("RDMA_CM_EVENT_ADDR_RESOLVED \n");
InitConnection(event_copy.id);
pre_conn_cb_(event_copy.id);
TEST_NZ(rdma_resolve_route(event_copy.id, timeout_));
break;
case RDMA_CM_EVENT_ROUTE_RESOLVED:
printf("RDMA_CM_EVENT_ROUTE_RESOLVED \n");
TEST_NZ(rdma_connect(event_copy.id, &params));
break;
case RDMA_CM_EVENT_ESTABLISHED:
printf("RDMA_CM_EVENT_ESTABLISHED \n");
// none
break;
case RDMA_CM_EVENT_DISCONNECTED:
printf("RDMA_CM_EVENT_DISCONNECTED \n");
rdma_destroy_qp(event_copy.id);
rdma_destroy_id(event_copy.id);
return;
default:
DIE("unknown event");
}
}
}

void Write(char *data, int size)
{
}

private:
void RegisterCallback(const int bufferSize)
{
pre_conn_cb_ = [&](rdma_cm_id *id)
{
ClientConnectionContext *ctx = (ClientConnectionContext *)id->context;
posix_memalign((void **)&ctx->buffer, sysconf(_SC_PAGESIZE), bufferSize);
TEST_Z(ctx->buffer_mr = ibv_reg_mr(context_->pd, ctx->buffer, bufferSize, 0));
posix_memalign((void **)&ctx->msg, sysconf(_SC_PAGESIZE), sizeof(*ctx->msg));
TEST_Z(ctx->msg_mr = ibv_reg_mr(context_->pd, ctx->msg, sizeof(*ctx->msg), IBV_ACCESS_LOCAL_WRITE));
ClientPostReceive(id);
};

completion_cb_ = [](ibv_wc *wc)
{
rdma_cm_id *id = (rdma_cm_id *)(uintptr_t)(wc->wr_id);
ClientConnectionContext *ctx = (ClientConnectionContext *)id->context;
if (wc->opcode & IBV_WC_RECV)
{
if (ctx->msg->id == MSG_MR)
{
printf("received MR, sending file name\n");
ctx->peer_addr = ctx->msg->data.mr.addr;
ctx->peer_rkey = ctx->msg->data.mr.rkey;
memset(((ClientConnectionContext *)id->context)->buffer, 'a', 20);
((ClientConnectionContext *)id->context)->buffer[21] = '\0';
ClientWriteRemote(id, 21);
printf("received MR, sending file name\n");
}
else if (ctx->msg->id == MSG_READY)
{
printf("received READY, sending chunk\n");
memset(((ClientConnectionContext *)id->context)->buffer, 'a', 20);
((ClientConnectionContext *)id->context)->buffer[21] = '\0';
ClientWriteRemote(id, 21);
printf("received READY, sending chunk\n");
}
else if (ctx->msg->id == MSG_DONE)
{
printf("received DONE, disconnecting\n");
rdma_disconnect(id);
return;
}
ClientPostReceive(id);
}
};
}

void BuildQPAttribute(ibv_qp_init_attr *qp_attr)
{
memset(qp_attr, 0, sizeof(ibv_qp_init_attr));
qp_attr->send_cq = context_->cq;
qp_attr->recv_cq = context_->cq;
qp_attr->qp_type = IBV_QPT_RC;
qp_attr->cap.max_send_wr = 10;
qp_attr->cap.max_recv_wr = 10;
qp_attr->cap.max_send_sge = 1;
qp_attr->cap.max_recv_sge = 1;
}

void BuildContext(ibv_context *verbs)
{
if (context_)
{
if (context_->ctx != verbs)
DIE("cannot handle events in more than one context.");
return;
}

context_ = (Context *)malloc(sizeof(Context));
context_->ctx = verbs;
TEST_Z(context_->pd = ibv_alloc_pd(context_->ctx));
TEST_Z(context_->comp_channel = ibv_create_comp_channel(context_->ctx));
TEST_Z(context_->cq = ibv_create_cq(context_->ctx, 10, NULL, context_->comp_channel, 0)); /* cqe=10 is arbitrary */
TEST_NZ(ibv_req_notify_cq(context_->cq, 0));

// create poll thread
auto poller = [&]()
{
ibv_cq *cq;
ibv_wc wc;
void *ctx = NULL;
while (1)
{
TEST_NZ(ibv_get_cq_event(context_->comp_channel, &cq, &ctx));
ibv_ack_cq_events(cq, 1);
TEST_NZ(ibv_req_notify_cq(cq, 0));

while (ibv_poll_cq(cq, 1, &wc))
{
if (wc.status == IBV_WC_SUCCESS)
{
completion_cb_(&wc);
}
else
{
DIE("poll_cq: status is not IBV_WC_SUCCESS");
}
}
}
};
thread_pool_->AddJob(poller);
}

void InitConnection(rdma_cm_id *id)
{
ibv_qp_init_attr qp_attr;
BuildContext(id->verbs);
BuildQPAttribute(&qp_attr);
TEST_NZ(rdma_create_qp(id, context_->pd, &qp_attr));
}

std::function<void(rdma_cm_id *)> pre_conn_cb_;
std::function<void(ibv_wc *)> completion_cb_;
ThreadPool *thread_pool_;
// context
rdma_cm_id *speaker_;
rdma_event_channel *channel_;
Context *context_;
int timeout_;
int buffer_size_;
};
47 changes: 47 additions & 0 deletions include/dijiang/RdmaClientSocketHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once

#include "Debug.hpp"
#include "Context.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <rdma/rdma_cma.h>

static void ClientWriteRemote(struct rdma_cm_id *id, uint32_t len)
{
ClientConnectionContext *ctx = (ClientConnectionContext *)id->context;
ibv_send_wr wr, *bad_wr = NULL;
ibv_sge sge;
memset(&wr, 0, sizeof(wr));
wr.wr_id = (uintptr_t)id;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.send_flags = IBV_SEND_SIGNALED;
wr.imm_data = htonl(len);
wr.wr.rdma.remote_addr = ctx->peer_addr;
wr.wr.rdma.rkey = ctx->peer_rkey;
if (len)
{
wr.sg_list = &sge;
wr.num_sge = 1;
sge.addr = (uintptr_t)ctx->buffer;
sge.length = len;
sge.lkey = ctx->buffer_mr->lkey;
}
TEST_NZ(ibv_post_send(id->qp, &wr, &bad_wr));
}

static void ClientPostReceive(struct rdma_cm_id *id)
{
ClientConnectionContext *ctx = (ClientConnectionContext *)id->context;
ibv_recv_wr wr, *bad_wr = NULL;
ibv_sge sge;
memset(&wr, 0, sizeof(ibv_recv_wr));
wr.wr_id = (uintptr_t)id;
wr.sg_list = &sge;
wr.num_sge = 1;
sge.addr = (uintptr_t)ctx->msg;
sge.length = sizeof(*ctx->msg);
sge.lkey = ctx->msg_mr->lkey;
TEST_NZ(ibv_post_recv(id->qp, &wr, &bad_wr));
}
20 changes: 10 additions & 10 deletions include/dijiang/RdmaServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "Context.h"
#include "Debug.hpp"
#include "ThreadPool.hpp"
#include "RdmaSocketHelper.h"
#include "RdmaServerSocketHelper.h"

#include <netdb.h>
#include <unistd.h>
Expand Down Expand Up @@ -93,21 +93,21 @@ class RdmaServerSocket
completion_cb_ = [func](ibv_wc *wc)
{
rdma_cm_id *id = (rdma_cm_id *)(uintptr_t)wc->wr_id;
ConnectionContext *ctx = (ConnectionContext *)id->context;
ServerConnectionContext *ctx = (ServerConnectionContext *)id->context;
if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM)
{
uint32_t size = ntohl(wc->imm_data);
if (size == 0)
{
ctx->msg->id = MSG_DONE;
SendMessage(id);
ServerSendMessage(id);
}
else
{
func(ctx->buffer, size);
PostReceive(id);
ServerPostReceive(id);
ctx->msg->id = MSG_READY;
SendMessage(id);
ServerSendMessage(id);
}
}
};
Expand Down Expand Up @@ -183,27 +183,27 @@ class RdmaServerSocket
{
pre_conn_cb_ = [&, bufferSize](rdma_cm_id *id)
{
ConnectionContext *ctx = (ConnectionContext *)malloc(sizeof(ConnectionContext));
ServerConnectionContext *ctx = (ServerConnectionContext *)malloc(sizeof(ServerConnectionContext));
id->context = ctx;
posix_memalign((void **)&ctx->buffer, sysconf(_SC_PAGESIZE), bufferSize);
TEST_Z(ctx->buffer_mr = ibv_reg_mr(context_->pd, ctx->buffer, bufferSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE));
posix_memalign((void **)&ctx->msg, sysconf(_SC_PAGESIZE), sizeof(*ctx->msg));
TEST_Z(ctx->msg_mr = ibv_reg_mr(context_->pd, ctx->msg, sizeof(*ctx->msg), 0));
PostReceive(id);
ServerPostReceive(id);
};

connect_cb_ = [&](rdma_cm_id *id)
{
ConnectionContext *ctx = (ConnectionContext *)id->context;
ServerConnectionContext *ctx = (ServerConnectionContext *)id->context;
ctx->msg->id = MSG_MR;
ctx->msg->data.mr.addr = (uintptr_t)ctx->buffer_mr->addr;
ctx->msg->data.mr.rkey = ctx->buffer_mr->rkey;
SendMessage(id);
ServerSendMessage(id);
};

disconnect_cb_ = [](rdma_cm_id *id)
{
ConnectionContext *ctx = (ConnectionContext *)id->context;
ServerConnectionContext *ctx = (ServerConnectionContext *)id->context;
ibv_dereg_mr(ctx->buffer_mr);
ibv_dereg_mr(ctx->msg_mr);
free(ctx->buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <cstring>
#include <rdma/rdma_cma.h>

static void PostReceive(rdma_cm_id *id)
static void ServerPostReceive(rdma_cm_id *id)
{
ibv_recv_wr wr, *bad_wr = NULL;
memset(&wr, 0, sizeof(wr));
Expand All @@ -19,9 +19,9 @@ static void PostReceive(rdma_cm_id *id)
TEST_NZ(ibv_post_recv(id->qp, &wr, &bad_wr));
}

static void SendMessage(rdma_cm_id *id)
static void ServerSendMessage(rdma_cm_id *id)
{
ConnectionContext *ctx = (ConnectionContext *)id->context;
ServerConnectionContext *ctx = (ServerConnectionContext *)id->context;
ibv_send_wr wr, *bad_wr = NULL;
ibv_sge sge;
memset(&wr, 0, sizeof(wr));
Expand Down
Loading

0 comments on commit 432d8d7

Please sign in to comment.