Commit a099eeed authored by Philipp Samfass's avatar Philipp Samfass
Browse files

added support for custom MPI communicators + added support for MPI_Iallgather/MPI_Allgather

parent 0808400e
......@@ -88,8 +88,8 @@ int getNumberOfTeams() {
return getWorldSize() / getTeamSize();
}
MPI_Comm getTeamComm() {
return TMPI_COMM_TEAM;
MPI_Comm getTeamComm(MPI_Comm comm) {
return (comm==MPI_COMM_WORLD) ? TMPI_COMM_TEAM : comm;
}
int freeTeamComm() {
......@@ -175,7 +175,7 @@ void remapStatus(MPI_Status *status) {
}
int synchroniseRanksInTeam() {
return PMPI_Barrier(getTeamComm());
return PMPI_Barrier(getTeamComm(MPI_COMM_WORLD));
}
int synchroniseRanksGlobally() {
......
......@@ -2,7 +2,7 @@
* RankOperations.h
*
* Created on: 2 Mar 2018
* Author: Ben Hazelwood
* Author: Ben Hazelwood, Philipp Samfass
*/
#ifndef RANK_H_
......@@ -50,7 +50,7 @@ int getNumberOfTeams();
int getTeam();
/* The communicator used by this team */
MPI_Comm getTeamComm();
MPI_Comm getTeamComm(MPI_Comm comm);
int freeTeamComm();
/* The duplicate MPI_COMM_WORLD used by the library*/
......
......@@ -26,45 +26,79 @@ int MPI_Is_thread_main(int* flag) {
}
int MPI_Comm_rank(MPI_Comm comm, int *rank) {
assert(comm == MPI_COMM_WORLD);
// todo: assert that a team comm is used
//assert(comm == MPI_COMM_WORLD);
*rank = getTeamRank();
return MPI_SUCCESS;
}
int MPI_Comm_size(MPI_Comm comm, int *size) {
assert(comm == MPI_COMM_WORLD);
//assert(comm == MPI_COMM_WORLD);
*size = getTeamSize();
return MPI_SUCCESS;
}
int MPI_Comm_dup(MPI_Comm comm, MPI_Comm *newcomm) {
int err = PMPI_Comm_dup(getTeamComm(comm), newcomm);
logInfo("Created communicator " << *newcomm);
assert(err==MPI_SUCCESS);
return err;
}
int MPI_Comm_free(MPI_Comm *comm) {
assert(*comm != MPI_COMM_WORLD);
logInfo("Free communicator " << *comm);
int err = PMPI_Comm_free(comm);
assert(err==MPI_SUCCESS);
return err;
}
int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
int tag, MPI_Comm comm) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm());
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm(comm));
logInfo("Send to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
return err;
}
int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
MPI_Comm comm, MPI_Status *status) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(), status);
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(comm), status);
logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
return err;
}
int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype,
MPI_Comm comm) {
int err = PMPI_Allgather(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
getTeamComm(comm));
return err;
}
int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype,
MPI_Comm comm, MPI_Request * request) {
int err = PMPI_Iallgather(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype,
getTeamComm(comm), request);
return err;
}
int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
int tag, MPI_Comm comm, MPI_Request *request) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(), request);
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(comm), request);
logInfo("Isend to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
return err;
}
int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
MPI_Comm comm, MPI_Request *request) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Irecv(buf, count, datatype, source, tag, getTeamComm(), request);
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Irecv(buf, count, datatype, source, tag, getTeamComm(comm), request);
logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
return err;
}
......@@ -90,46 +124,46 @@ int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) {
}
int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
assert(comm == MPI_COMM_WORLD);
//assert(comm == MPI_COMM_WORLD);
logInfo("Probe initialised (SOURCE=" << source << ",TAG=" << tag << ")");
int err = PMPI_Probe(source, tag, getTeamComm(), status);
int err = PMPI_Probe(source, tag, getTeamComm(comm), status);
logInfo("Probe finished (SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
return err;
}
int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
MPI_Status *status) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Iprobe(source, tag, getTeamComm(), flag, status);
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Iprobe(source, tag, getTeamComm(comm), flag, status);
logInfo("Iprobe finished (FLAG=" << *flag << ",SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
return err;
}
int MPI_Barrier(MPI_Comm comm) {
assert(comm == MPI_COMM_WORLD);
//assert(comm == MPI_COMM_WORLD);
int err = synchroniseRanksInTeam();
return err;
}
int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
MPI_Comm comm ) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Bcast(buffer, count, datatype, root, getTeamComm());
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Bcast(buffer, count, datatype, root, getTeamComm(comm));
return err;
}
int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getTeamComm());
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getTeamComm(comm));
return err;
}
int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype,
MPI_Comm comm) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getTeamComm());
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getTeamComm(comm));
return err;
}
......@@ -137,8 +171,8 @@ int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
const int *sdispls, MPI_Datatype sendtype, void *recvbuf,
const int *recvcounts, const int *rdispls, MPI_Datatype recvtype,
MPI_Comm comm) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getTeamComm());
//assert(comm == MPI_COMM_WORLD);
int err = PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getTeamComm(comm));
return err;
}
......@@ -159,8 +193,8 @@ int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
Timing::markTimeline(sendtag, sendbuf, sendcount, sendtype);
}
} else {
assert(comm == MPI_COMM_WORLD);
MPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getTeamComm(),status);
//assert(comm == MPI_COMM_WORLD);
MPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getTeamComm(comm),status);
}
return MPI_SUCCESS;
}
......@@ -177,6 +211,6 @@ int MPI_Finalize() {
int MPI_Abort(MPI_Comm comm, int errorcode) {
assert(comm == MPI_COMM_WORLD);
int err = PMPI_Abort(getTeamComm(), errorcode);
int err = PMPI_Abort(getTeamComm(comm), errorcode);
return err;
}
......@@ -43,6 +43,20 @@ int MPI_Comm_rank(MPI_Comm comm, int *rank);
*/
int MPI_Comm_size(MPI_Comm comm, int *size);
/**
*
* @param comm
* @param newcomm
* @return
*/
int MPI_Comm_dup(MPI_Comm comm, MPI_Comm *newcomm);
/**
*
* @param comm
* @return
*/
int MPI_Comm_free(MPI_Comm *comm);
/**
* Sends only to the corresponding replica of dest
......@@ -83,6 +97,14 @@ int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status);
int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
MPI_Status *status);
int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype,
MPI_Comm comm);
int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype,
MPI_Comm comm, MPI_Request *request);
int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status);
int MPI_Barrier(MPI_Comm comm);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment