05.11., 9:00 - 11:00: Due to updates GitLab may be unavailable for some minutes between 09:00 and 11:00.

Wrapper.cpp 7.97 KB
Newer Older
1
#include "Wrapper.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
2 3

#include <cassert>
4
#include <unistd.h>
Ben Hazelwood's avatar
Ben Hazelwood committed
5

6
#include "Logging.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
7
#include "Rank.h"
8
#include "Timing.h"
9
#include "CommStats.h"
10 11

int MPI_Init(int *argc, char*** argv) {
12
  int err = PMPI_Init(argc, argv);
Ben Hazelwood's avatar
Ben Hazelwood committed
13
  initialiseTMPI();
14
  return err;
15 16 17
}

int MPI_Init_thread( int *argc, char ***argv, int required, int *provided ) {
18
  int err = PMPI_Init_thread(argc, argv, required, provided);
Ben Hazelwood's avatar
Ben Hazelwood committed
19
  initialiseTMPI();
20
  return err;
21 22 23 24
}

int MPI_Is_thread_main(int* flag) {
  // See header documentation
Ben Hazelwood's avatar
Ben Hazelwood committed
25
  *flag = getTeam();
26 27 28 29
  return MPI_SUCCESS;
}

// Intercepts MPI_Comm_rank.
// - For MPI_COMM_WORLD, reports the caller's rank within its team
//   (getTeamRank()) and returns MPI_SUCCESS.
// - For any other communicator, delegates to PMPI_Comm_rank and returns its
//   error code. Previously that code was discarded and MPI_SUCCESS was
//   always returned, hiding failures.
int MPI_Comm_rank(MPI_Comm comm, int *rank) {
  // todo: assert that a team comm is used
  //assert(comm == MPI_COMM_WORLD);
  if (comm == MPI_COMM_WORLD) {
    *rank = getTeamRank();
    return MPI_SUCCESS;
  }
  return PMPI_Comm_rank(comm, rank);
}

// Intercepts MPI_Comm_size.
// - For MPI_COMM_WORLD, reports the size of the caller's team
//   (getTeamSize()) and returns MPI_SUCCESS.
// - For any other communicator, delegates to PMPI_Comm_size and returns its
//   error code. This matches MPI_Comm_rank; previously every communicator
//   (e.g. MPI_COMM_SELF) reported the team size.
int MPI_Comm_size(MPI_Comm comm, int *size) {
  //assert(comm == MPI_COMM_WORLD);
  if (comm == MPI_COMM_WORLD) {
    *size = getTeamSize();
    return MPI_SUCCESS;
  }
  return PMPI_Comm_size(comm, size);
}

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
// Intercepts MPI_Comm_dup: duplicates the communicator that getTeamComm()
// maps `comm` to, so the new communicator spans the same ranks as that
// team communicator. Logs the handle and asserts success in debug builds.
int MPI_Comm_dup(MPI_Comm comm, MPI_Comm *newcomm) {
  const int rc = PMPI_Comm_dup(getTeamComm(comm), newcomm);
  logInfo("Created communicator " << *newcomm);
  assert(rc == MPI_SUCCESS);
  return rc;
}

// Intercepts MPI_Comm_free. Freeing MPI_COMM_WORLD is rejected in debug
// builds. The handle is logged before PMPI_Comm_free overwrites it.
int MPI_Comm_free(MPI_Comm *comm) {
  assert(*comm != MPI_COMM_WORLD);
  logInfo("Free communicator " << *comm);
  const int rc = PMPI_Comm_free(comm);
  assert(rc == MPI_SUCCESS);
  return rc;
}

60 61
int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
              int tag, MPI_Comm comm) {
62 63 64
#if COMM_STATS
  CommunicationStatistics::trackSend(datatype,  count);
#endif
65 66
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm(comm));
67
  logInfo("Send to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
68 69 70 71 72
  return err;
}

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
              MPI_Comm comm, MPI_Status *status) {
73
  //assert(comm == MPI_COMM_WORLD);
74 75 76
#if COMM_STATS
  CommunicationStatistics::trackReceive(datatype,  count);
#endif
77
  int err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(comm), status);
78
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
79 80 81
  return err;
}

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
// Intercepts MPI_Allgather: runs the gather on getTeamComm(comm) and returns
// the PMPI result directly.
int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                  void *recvbuf, int recvcount, MPI_Datatype recvtype,
                  MPI_Comm comm) {
  return PMPI_Allgather(sendbuf, sendcount, sendtype,
                        recvbuf, recvcount, recvtype,
                        getTeamComm(comm));
}

// Intercepts MPI_Iallgather: starts the non-blocking gather on
// getTeamComm(comm) and returns the PMPI result directly.
int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                   void *recvbuf, int recvcount, MPI_Datatype recvtype,
                   MPI_Comm comm, MPI_Request * request) {
  return PMPI_Iallgather(sendbuf, sendcount, sendtype,
                         recvbuf, recvcount, recvtype,
                         getTeamComm(comm), request);
}

100 101
int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
               int tag, MPI_Comm comm, MPI_Request *request) {
102
  //assert(comm == MPI_COMM_WORLD);
103 104 105
#if COMM_STATS
  CommunicationStatistics::trackSend(datatype,  count);
#endif
106
  int err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(comm), request);
107
  logInfo("Isend to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
108 109 110 111 112
  return err;
}

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
               MPI_Comm comm, MPI_Request *request) {
113
  //assert(comm == MPI_COMM_WORLD);
114 115 116
#if COMM_STATS
  CommunicationStatistics::trackReceive(datatype,  count);
#endif
117
  int err = PMPI_Irecv(buf, count, datatype, source, tag, getTeamComm(comm), request);
118
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
119 120 121 122 123
  return err;
}

// Intercepts MPI_Wait: brackets the blocking PMPI_Wait with log lines and
// returns its result unchanged.
int MPI_Wait(MPI_Request *request, MPI_Status *status) {
  logInfo("Wait initialised");
  const int rc = PMPI_Wait(request, status);
  logInfo("Wait completed");
  return rc;
}

129 130
int MPI_Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]) {
  logInfo("Waitall initialised with " << count << " requests");
131
  int err = PMPI_Waitall(count, array_of_requests, array_of_statuses);
132 133 134 135
  logInfo("Waitall completed with " << count << " requests");
  return err;
}

136
// Intercepts MPI_Test: forwards to PMPI_Test and logs the outcome.
// The status fields are only logged when they are meaningful:
// - `status` may be MPI_STATUS_IGNORE (a null pointer in common
//   implementations), which must not be dereferenced;
// - when `*flag` is false, MPI leaves the status object undefined.
// Returns the PMPI_Test result unchanged.
int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) {
  int err = PMPI_Test(request, flag, status);
  if (status != MPI_STATUS_IGNORE && *flag) {
    logInfo("Test completed (FLAG=" << *flag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
  } else {
    logInfo("Test completed (FLAG=" << *flag << ")");
  }
  return err;
}

int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
143
  //assert(comm == MPI_COMM_WORLD);
144
  logInfo("Probe initialised (SOURCE=" << source << ",TAG=" << tag << ")");
145
  int err = PMPI_Probe(source, tag, getTeamComm(comm), status);
146
  logInfo("Probe finished (SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
147 148 149 150 151
  return err;
}

int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
                MPI_Status *status) {
152 153
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Iprobe(source, tag, getTeamComm(comm), flag, status);
154
  logInfo("Iprobe finished (FLAG=" << *flag << ",SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
155 156 157 158
  return err;
}

// Intercepts MPI_Barrier. The `comm` argument is not consulted: the barrier
// is always performed by synchroniseRanksInTeam(), whose result is returned.
// NOTE(review): presumably this synchronises only the caller's team — confirm.
int MPI_Barrier(MPI_Comm comm) {
  return synchroniseRanksInTeam();
}

Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
164 165
int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
               MPI_Comm comm ) {
166 167
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Bcast(buffer, count, datatype, root, getTeamComm(comm));
Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
168 169 170 171 172
  return err;
}

int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
                  MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
173 174
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getTeamComm(comm));
Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
175 176 177 178 179 180
  return err;
}

int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                 void *recvbuf, int recvcount, MPI_Datatype recvtype,
                 MPI_Comm comm) {
181 182
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getTeamComm(comm));
Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
183 184 185 186 187 188 189
  return err;
}

int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
                  const int *sdispls, MPI_Datatype sendtype, void *recvbuf,
                  const int *recvcounts, const int *rdispls, MPI_Datatype recvtype,
                  MPI_Comm comm) {
190 191
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getTeamComm(comm));
Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
192 193 194
  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
195
double MPI_Wtime() {
Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
196
  // If you mark on the timeline here expect negative time values (you've been warned)
197
  return PMPI_Wtime();
Ben Hazelwood's avatar
Ben Hazelwood committed
198 199
}

200 201 202 203 204 205
int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                int dest, int sendtag,
                void *recvbuf, int recvcount, MPI_Datatype recvtype,
                int source, int recvtag,
                MPI_Comm comm, MPI_Status *status) {
  if (comm == MPI_COMM_SELF) {
206
    if (sendcount == 0) {
207
      Timing::markTimeline(sendtag);
208
    } else {
209
      Timing::markTimeline(sendtag, sendbuf, sendcount, sendtype);
210
    }
211
  } else {
212
    //assert(comm == MPI_COMM_WORLD);
213
    PMPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getTeamComm(comm),status);
214 215 216 217
  }
  return MPI_SUCCESS;
}

218
int MPI_Finalize() {
219
  logInfo("Finalize");
220
  Timing::finaliseTiming();
221 222
  // Wait for all replicas before finalising
  PMPI_Barrier(MPI_COMM_WORLD);
Ben Hazelwood's avatar
Ben Hazelwood committed
223
  freeTeamComm();
Ben Hazelwood's avatar
Ben Hazelwood committed
224
  Timing::outputTiming();
225 226 227
#if COMM_STATS
  CommunicationStatistics::outputCommunicationStatistics();
#endif
228 229 230
#ifdef DirtyCleanUp
  return MPI_SUCCESS;
#endif
231 232
  return PMPI_Finalize();
}
Benjamin Hazelwood's avatar
Benjamin Hazelwood committed
233 234 235

// Intercepts MPI_Abort. Only MPI_COMM_WORLD is accepted (checked in debug
// builds). The abort is issued on the communicator returned by
// getTeamComm(comm), and the PMPI result is returned.
int MPI_Abort(MPI_Comm comm, int errorcode) {
  assert(comm == MPI_COMM_WORLD);
  const auto target = getTeamComm(comm);
  return PMPI_Abort(target, errorcode);
}