mpi_dp: add support for same proc communication
vicentebolea committed Mar 19, 2022
1 parent 3fbf470 commit 1deeb7b
Showing 2 changed files with 107 additions and 73 deletions.
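The gist of this commit is a same-process fast path. The reader now sends its PID in every read request; when the writer sees its own PID, it attaches the requested bytes directly to the read-reply message and the reader simply memcpy's them out, so no MPI port/connect/accept round trip is needed between a writer and a reader living in the same process (as in the threaded staging test changed below). A minimal sketch of the idea follows; the names are illustrative, not the actual functions in mpi_dp.c.

/* Sketch of the same-process shortcut (hypothetical helpers; the real logic
 * lives in MpiReadRequestHandler and MpiReadReplyHandler in the diff below). */
#include <string.h>
#include <unistd.h>

struct ReadRequest { int PID; long Offset; long Length; };
struct ReadReply   { int PID; char *Data; long DataLength; };

/* Writer side: if the requester is this very process, carry the bytes in the reply. */
static void build_reply(const struct ReadRequest *req, char *block, struct ReadReply *rep)
{
    rep->PID = getpid();
    rep->DataLength = req->Length;
    rep->Data = (req->PID == getpid()) ? block + req->Offset : NULL;
}

/* Reader side: same process means a plain memcpy; otherwise fall back to MPI. */
static void consume_reply(const struct ReadReply *rep, char *dest)
{
    if (rep->PID == getpid())
        memcpy(dest, rep->Data, rep->DataLength);
    /* else: MPI_Recv over the per-writer communicator, as in the diff below */
}

The remote case is unchanged in spirit: data still travels over a dynamically established MPI intercommunicator (see the lazy connect/accept sketch after this file's diff).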
179 changes: 106 additions & 73 deletions source/adios2/toolkit/sst/dp/mpi_dp.c
@@ -4,6 +4,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "sst_data.h"

@@ -35,7 +36,6 @@ typedef struct _Mpi_WSR_Stream
CP_PeerCohort PeerCohort;
int ReaderCohortSize;
struct _MpiReaderContactInfo *ReaderContactInfo;
MPI_Comm MpiComm;
char MpiPortName[MPI_MAX_PORT_NAME];
} * Mpi_WSR_Stream;

@@ -64,6 +64,7 @@ typedef struct _MpiReaderContactInfo
{
char *ContactString;
void *RS_Stream;
MPI_Comm MpiComm;
} * MpiReaderContactInfo;

typedef struct _MpiWriterContactInfo
@@ -82,6 +83,7 @@ typedef struct _MpiReadRequestMsg
void *RS_Stream;
int RequestingRank;
int NotifyCondition;
int PID;
} * MpiReadRequestMsg;

static FMField MpiReadRequestList[] = {
@@ -97,6 +99,8 @@ static FMField MpiReadRequestList[] = {
FMOffset(MpiReadRequestMsg, RequestingRank)},
{"NotifyCondition", "integer", sizeof(int),
FMOffset(MpiReadRequestMsg, NotifyCondition)},
{"PID", "integer", sizeof(int),
FMOffset(MpiReadRequestMsg, PID)},
{NULL, NULL, 0, 0}};

static FMStructDescRec MpiReadRequestStructs[] = {
@@ -111,7 +115,9 @@ typedef struct _MpiReadReplyMsg
void *RS_Stream;
int NotifyCondition;
char *MpiPortName;
char *Data;
int MpiTag;
int PID;
} * MpiReadReplyMsg;

static FMField MpiReadReplyList[] = {
@@ -124,6 +130,12 @@ static FMField MpiReadReplyList[] = {
FMOffset(MpiReadReplyMsg, NotifyCondition)},
{"MpiPortName", "string", sizeof(char *),
FMOffset(MpiReadReplyMsg, MpiPortName)},
{"Data", "char[DataLength]", sizeof(char),
FMOffset(MpiReadReplyMsg, Data)},
{"MpiTag", "integer", sizeof(int),
FMOffset(MpiReadReplyMsg, MpiTag)},
{"PID", "integer", sizeof(int),
FMOffset(MpiReadReplyMsg, PID)},
{NULL, NULL, 0, 0}};

static FMStructDescRec MpiReadReplyStructs[] = {
@@ -147,6 +159,8 @@ static FMField MpiWriterContactList[] = {
FMOffset(MpiWriterContactInfo, ContactString)},
{"writer_ID", "integer", sizeof(void *),
FMOffset(MpiWriterContactInfo, WS_Stream)},
{"MpiComm", "integer", sizeof(int),
FMOffset(MpiWriterContactInfo, MpiComm)},
{NULL, NULL, 0, 0}};

static FMStructDescRec MpiWriterContactStructs[] = {
@@ -159,7 +173,7 @@ static void MpiReadReplyHandler(CManager cm, CMConnection conn, void *msg_v,

static void Initialize_MPI()
{
int IsInitialized = 0;
static int IsInitialized = 0;
int provided;

MPI_Initialized(&IsInitialized);
@@ -278,6 +292,12 @@ static void MpiReadRequestHandler(CManager cm, CMConnection conn, void *msg_v,
ReadReplyMsg.RS_Stream = ReadRequestMsg->RS_Stream;
ReadReplyMsg.NotifyCondition = ReadRequestMsg->NotifyCondition;
ReadReplyMsg.MpiPortName = WSR_Stream->MpiPortName;
ReadReplyMsg.PID = getpid();

if (ReadRequestMsg->PID == getpid())
{
ReadReplyMsg.Data = tmp->Data->block + ReadRequestMsg->Offset;
}

Svcs->verbose(WS_Stream->CP_Stream, DPTraceVerbose,
"Sending a reply to reader rank %d for remote memory read,"
@@ -288,25 +308,29 @@ static void MpiReadRequestHandler(CManager cm, CMConnection conn, void *msg_v,
ReadRequestMsg->RequestingRank, WS_Stream->ReadReplyFormat,
&ReadReplyMsg);

// Send the actual Data using MPI
int worldErrHandler;
MPI_Comm_get_errhandler(MPI_COMM_WORLD, &worldErrHandler);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
int ret = MPI_Send(outboundBuffer, ReadRequestMsg->Length, MPI_CHAR, 0,
ReadRequestMsg->NotifyCondition, WSR_Stream->MpiComm);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, worldErrHandler);

if (ret != MPI_SUCCESS)
if (ReadRequestMsg->PID != getpid())
{
WSR_Stream->MpiComm = 0;
MPI_Comm_accept(WSR_Stream->MpiPortName, MPI_INFO_NULL, 0, MPI_COMM_SELF,
&WSR_Stream->MpiComm);
MPI_Send(outboundBuffer, ReadRequestMsg->Length, MPI_CHAR, 0,
ReadRequestMsg->NotifyCondition, WSR_Stream->MpiComm);
// Send the actual Data using MPI
MPI_Comm* comm = &WSR_Stream->ReaderContactInfo[ReadRequestMsg->RequestingRank].MpiComm;
int worldErrHandler;
MPI_Comm_get_errhandler(MPI_COMM_WORLD, &worldErrHandler);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
int ret = MPI_Send(outboundBuffer, ReadRequestMsg->Length, MPI_CHAR, 0,
ReadRequestMsg->NotifyCondition, *comm);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, worldErrHandler);

if (ret != MPI_SUCCESS)
{
MPI_Comm_accept(WSR_Stream->MpiPortName, MPI_INFO_NULL, 0, MPI_COMM_SELF,
comm);
MPI_Send(outboundBuffer, ReadRequestMsg->Length, MPI_CHAR, 0,
ReadRequestMsg->NotifyCondition, *comm);
}
}

Svcs->verbose(WS_Stream->CP_Stream, DPTraceVerbose,
"MPI_DP: Connected to client\n");
"MPI_DP: Connected to client, num of clients=%d\n",
WSR_Stream->ReaderCohortSize);

free(outboundBuffer);

@@ -338,36 +362,44 @@ static void MpiReadReplyHandler(CManager cm, CMConnection conn, void *msg_v,
Handle->Rank, ReadReplyMsg->NotifyCondition,
ReadReplyMsg->DataLength);

/*
* `Handle` contains the full request info and is `client_data`
* associated with the CMCondition. Once we get it, MPI copy the incoming
* data to the buffer area given by the request
*/
pthread_mutex_lock(&mpi_comm_mutex);
MPI_Comm comm = RS_Stream->WriterContactInfo[Handle->Rank].MpiComm;
pthread_mutex_unlock(&mpi_comm_mutex);

int worldErrHandler;
MPI_Comm_get_errhandler(MPI_COMM_WORLD, &worldErrHandler);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
int ret = MPI_Recv(Handle->Buffer, ReadReplyMsg->DataLength, MPI_CHAR, 0,
ReadReplyMsg->NotifyCondition, comm, MPI_STATUS_IGNORE);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, worldErrHandler);

if (ret != MPI_SUCCESS)
if(ReadReplyMsg->PID == getpid())
{
memcpy(Handle->Buffer, ReadReplyMsg->Data, ReadReplyMsg->DataLength);
}
else
{
MPI_Comm_connect(ReadReplyMsg->MpiPortName, MPI_INFO_NULL, 0, MPI_COMM_SELF,
&comm);

Svcs->verbose(RS_Stream->CP_Stream, DPTraceVerbose,
"MPI_DP: Connecting to MPI Server\n");
MPI_Recv(Handle->Buffer, ReadReplyMsg->DataLength, MPI_CHAR, 0,
/*
* `Handle` contains the full request info and is `client_data`
* associated with the CMCondition. Once we get it, MPI copy the incoming
* data to the buffer area given by the request
*/
pthread_mutex_lock(&mpi_comm_mutex);
MPI_Comm comm = RS_Stream->WriterContactInfo[Handle->Rank].MpiComm;
pthread_mutex_unlock(&mpi_comm_mutex);

int worldErrHandler;
MPI_Comm_get_errhandler(MPI_COMM_WORLD, &worldErrHandler);
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
int ret = MPI_Recv(Handle->Buffer, ReadReplyMsg->DataLength, MPI_CHAR, 0,
ReadReplyMsg->NotifyCondition, comm, MPI_STATUS_IGNORE);
}
MPI_Comm_set_errhandler(MPI_COMM_WORLD, worldErrHandler);

pthread_mutex_lock(&mpi_comm_mutex);
RS_Stream->WriterContactInfo[Handle->Rank].MpiComm = comm;
pthread_mutex_unlock(&mpi_comm_mutex);
if (ret != MPI_SUCCESS)
{
MPI_Comm_connect(ReadReplyMsg->MpiPortName, MPI_INFO_NULL, 0, MPI_COMM_SELF,
&comm);

Svcs->verbose(RS_Stream->CP_Stream, DPTraceVerbose,
"MPI_DP: Connecting to MPI Server\n");
MPI_Recv(Handle->Buffer, ReadReplyMsg->DataLength, MPI_CHAR, 0,
ReadReplyMsg->NotifyCondition, comm, MPI_STATUS_IGNORE);
}

pthread_mutex_lock(&mpi_comm_mutex);
RS_Stream->WriterContactInfo[Handle->Rank].MpiComm = comm;
pthread_mutex_unlock(&mpi_comm_mutex);
}

/*
* Signal the condition to wake the reader if they are waiting.
@@ -591,6 +623,7 @@ static void *MpiReadRemoteMemory(CP_Services Svcs, DP_RS_Stream Stream_v,
ReadRequestMsg.RS_Stream = Stream;
ReadRequestMsg.RequestingRank = Stream->Rank;
ReadRequestMsg.NotifyCondition = ret->CMcondition;
ReadRequestMsg.PID = getpid();

Svcs->sendToPeer(Stream->CP_Stream, Stream->PeerCohort, Rank,
Stream->ReadRequestFormat, &ReadRequestMsg);
@@ -729,8 +762,6 @@ static void MpiReleaseTimestep(CP_Services Svcs, DP_WS_Stream Stream_v,
pthread_mutex_unlock(&ts_mutex);
}

static struct _CP_DP_Interface mpiDPInterface;

static int MpiGetPriority(CP_Services Svcs, void *CP_Stream,
struct _SstParams *Params)
{
@@ -752,10 +783,18 @@ static void MpiNotifyConnFailure(CP_Services Svcs, DP_RS_Stream Stream_v,
static void MpiDestroyWriterPerReader(CP_Services Svcs,
DP_WSR_Stream WSR_Stream_v)
{
Mpi_WSR_Stream WSR_Stream = {0};
memcpy(&WSR_Stream, &WSR_Stream_v, sizeof(Mpi_WSR_Stream));
Mpi_WSR_Stream WSR_Stream = (Mpi_WSR_Stream)WSR_Stream_v;
Mpi_WS_Stream WS_Stream = WSR_Stream->WS_Stream;
MpiWriterContactInfo WriterContactInfo = {0};

MPI_Close_port(WSR_Stream->MpiPortName);

for (int i = 0; i < WSR_Stream->ReaderCohortSize; i++)
{
if (WSR_Stream->ReaderContactInfo[i].MpiComm)
{
MPI_Comm_disconnect(&WSR_Stream->ReaderContactInfo[i].MpiComm);
}
}

pthread_mutex_lock(&ws_mutex);
for (int i = 0; i < WS_Stream->ReaderCount; i++)
@@ -768,23 +807,16 @@ static void MpiDestroyWriterPerReader(CP_Services Svcs,
}
}

if (WSR_Stream->MpiComm)
{
MPI_Comm_disconnect(&WSR_Stream->MpiComm);
}
MPI_Close_port(WSR_Stream->MpiPortName);

if (WSR_Stream->ReaderContactInfo)
{
free(WSR_Stream->ReaderContactInfo);
}
WS_Stream->Readers = realloc(
WS_Stream->Readers, sizeof(*WSR_Stream) * (WS_Stream->ReaderCount - 1));
WS_Stream->ReaderCount--;
free(WSR_Stream);

pthread_mutex_unlock(&ws_mutex);

free(WSR_Stream);
}

static void MpiDestroyWriter(CP_Services Svcs, DP_WS_Stream WS_Stream_v)
@@ -818,23 +850,24 @@ static void MpiDestroyReader(CP_Services Svcs, DP_RS_Stream RS_Stream_v)

extern CP_DP_Interface LoadMpiDP()
{
memset(&mpiDPInterface, 0, sizeof(mpiDPInterface));
mpiDPInterface.ReaderContactFormats = MpiReaderContactStructs;
mpiDPInterface.WriterContactFormats = MpiWriterContactStructs;
mpiDPInterface.TimestepInfoFormats = NULL;
mpiDPInterface.initReader = MpiInitReader;
mpiDPInterface.initWriter = MpiInitWriter;
mpiDPInterface.initWriterPerReader = MpiInitWriterPerReader;
mpiDPInterface.provideWriterDataToReader = MpiProvideWriterDataToReader;
mpiDPInterface.readRemoteMemory = MpiReadRemoteMemory;
mpiDPInterface.waitForCompletion = MpiWaitForCompletion;
mpiDPInterface.provideTimestep = MpiProvideTimestep;
mpiDPInterface.releaseTimestep = MpiReleaseTimestep;
mpiDPInterface.getPriority = MpiGetPriority;
mpiDPInterface.destroyReader = MpiDestroyReader;
mpiDPInterface.destroyWriter = MpiDestroyWriter;
mpiDPInterface.destroyWriterPerReader = MpiDestroyWriterPerReader;
mpiDPInterface.notifyConnFailure = MpiNotifyConnFailure;
static struct _CP_DP_Interface mpiDPInterface = {
.ReaderContactFormats = MpiReaderContactStructs,
.WriterContactFormats = MpiWriterContactStructs,
.TimestepInfoFormats = NULL,
.initReader = MpiInitReader,
.initWriter = MpiInitWriter,
.initWriterPerReader = MpiInitWriterPerReader,
.provideWriterDataToReader = MpiProvideWriterDataToReader,
.readRemoteMemory = MpiReadRemoteMemory,
.waitForCompletion = MpiWaitForCompletion,
.provideTimestep = MpiProvideTimestep,
.releaseTimestep = MpiReleaseTimestep,
.getPriority = MpiGetPriority,
.destroyReader = MpiDestroyReader,
.destroyWriter = MpiDestroyWriter,
.destroyWriterPerReader = MpiDestroyWriterPerReader,
.notifyConnFailure = MpiNotifyConnFailure,
};

return &mpiDPInterface;
}
1 change: 1 addition & 0 deletions testing/adios2/engine/staging-common/TestThreads.cpp
@@ -150,6 +150,7 @@ TEST_F(TestThreads, Basic)
auto read_fut = std::async(std::launch::async, Read, BaseName, 0);
auto write_fut = std::async(std::launch::async, Write, BaseName, 0);
bool reader_success = read_fut.get();
sleep(1);
bool writer_success = write_fut.get();
EXPECT_TRUE(reader_success);
EXPECT_TRUE(writer_success);

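For context on the FMField tables edited above (the new PID, Data, and MpiTag entries): each row maps one struct member to a wire field by name, FFS type string, size, and offset, and a type such as "char[DataLength]" makes the marshalled message carry DataLength bytes reachable through the Data pointer. An illustrative field list for a made-up message type follows; it is not code from mpi_dp.c, and the header name may differ between FFS installs.

#include "fm.h" /* FFS marshalling: FMField, FMOffset (assumed header name) */

typedef struct _ExampleMsg
{
    int PID;         /* sender's process id, used for the same-process test */
    long DataLength; /* number of payload bytes */
    char *Data;      /* payload, serialized inline because of "char[DataLength]" */
} *ExampleMsg;

static FMField ExampleFieldList[] = {
    {"PID", "integer", sizeof(int), FMOffset(ExampleMsg, PID)},
    {"DataLength", "integer", sizeof(long), FMOffset(ExampleMsg, DataLength)},
    {"Data", "char[DataLength]", sizeof(char), FMOffset(ExampleMsg, Data)},
    {NULL, NULL, 0, 0}};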