Wrapper.cpp 11 KB
Newer Older
1
#include "Wrapper.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
2
3

#include <cassert>
4
#include <unistd.h>
Ben Hazelwood's avatar
Ben Hazelwood committed
5

6
#include "Logging.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
7
#include "Rank.h"
8
9
#include "Timing.h"

10
11
12
13
14
15

#ifdef USE_MPI_OFFLOADING
#include "mpi_offloading.h"
#include "mpi_offloading_common.h"
#endif

16

17
int MPI_Init(int *argc, char*** argv) {
18
  int err = PMPI_Init(argc, argv);
19
#ifdef USE_MPI_OFFLOADING
Philipp Samfaß's avatar
Philipp Samfaß committed
20
  smpi_init( (*argv)[0]);
21
22
23
24
25
  if(!_is_server)
    initialiseTMPI(_comm);
#else
  initialiseTMPI(MPI_COMM_WORLD);
#endif
26
  return err;
27
28
29
}

int MPI_Init_thread( int *argc, char ***argv, int required, int *provided ) {
30
  int err = PMPI_Init_thread(argc, argv, required, provided);
31
#ifdef USE_MPI_OFFLOADING
Philipp Samfaß's avatar
Philipp Samfaß committed
32
  smpi_init( (*argv)[0]);
33
34
35
36
37
  if(!_is_server)
    initialiseTMPI(_comm);
#else
  initialiseTMPI(MPI_COMM_WORLD);
#endif
38
  return err;
39
40
41
42
}

int MPI_Is_thread_main(int* flag) {
  // See header documentation
Ben Hazelwood's avatar
Ben Hazelwood committed
43
  *flag = getTeam();
44
45
46
  return MPI_SUCCESS;
}

Philipp Samfaß's avatar
Philipp Samfaß committed
47
48
49
50
51
52
53
54
55
56
/**
 * Interposed MPI_Comm_split.
 *
 * A split of MPI_COMM_WORLD is redirected to the offloading communicator _comm
 * when USE_MPI_OFFLOADING is defined. Every other communicator is split as-is,
 * so *newcomm is always produced by PMPI_Comm_split.
 *
 * NOTE(review): without offloading, splitting MPI_COMM_WORLD splits the full
 * world rather than getTeamComm(MPI_COMM_WORLD) as most other wrappers here
 * do — confirm this is intended.
 *
 * @return The error code of PMPI_Comm_split.
 */
int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm *newcomm) {
  MPI_Comm parent = comm;
#ifdef USE_MPI_OFFLOADING
  if (comm == MPI_COMM_WORLD) {
    parent = _comm;
  }
#endif
  return PMPI_Comm_split(parent, color, key, newcomm);
}

57
/**
 * Interposed MPI_Comm_rank.
 *
 * For MPI_COMM_WORLD the rank within the caller's team (getTeamRank()) is
 * reported; offloading server ranks instead report their rank in _comm.
 * Any other communicator is queried directly.
 *
 * @return MPI_SUCCESS for the team-rank path, otherwise the error code of
 *         PMPI_Comm_rank (no longer discarded).
 */
int MPI_Comm_rank(MPI_Comm comm, int *rank) {
  // TODO: assert that a team communicator is used.
  if (comm != MPI_COMM_WORLD) {
    return PMPI_Comm_rank(comm, rank);
  }
#ifdef USE_MPI_OFFLOADING
  if (_is_server) {
    return PMPI_Comm_rank(_comm, rank);
  }
#endif
  *rank = getTeamRank();
  return MPI_SUCCESS;
}

/**
 * Interposed MPI_Comm_size, mirroring MPI_Comm_rank.
 *
 * For MPI_COMM_WORLD the team size (getTeamSize()) is reported; offloading
 * server ranks instead report the size of _comm. Any other communicator is
 * queried directly, so sub-communicators report their real size instead of
 * the team size.
 *
 * @return MPI_SUCCESS for the team-size path, otherwise the error code of
 *         PMPI_Comm_size.
 */
int MPI_Comm_size(MPI_Comm comm, int *size) {
  if (comm != MPI_COMM_WORLD) {
    return PMPI_Comm_size(comm, size);
  }
#ifdef USE_MPI_OFFLOADING
  if (_is_server) {
    return PMPI_Comm_size(_comm, size);
  }
#endif
  *size = getTeamSize();
  return MPI_SUCCESS;
}

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/**
 * Interposed MPI_Comm_dup: duplicates getTeamComm(comm) and logs the new handle.
 *
 * @return The error code of PMPI_Comm_dup (asserted to be MPI_SUCCESS in
 *         debug builds).
 */
int MPI_Comm_dup(MPI_Comm comm, MPI_Comm *newcomm) {
  const MPI_Comm source = getTeamComm(comm);
  const int rc = PMPI_Comm_dup(source, newcomm);
  logInfo("Created communicator " << *newcomm);
  assert(rc == MPI_SUCCESS);
  return rc;
}

/**
 * Interposed MPI_Comm_free: logs and frees a user communicator.
 * Freeing MPI_COMM_WORLD is rejected by an assertion in debug builds.
 *
 * @return The error code of PMPI_Comm_free (asserted to be MPI_SUCCESS in
 *         debug builds).
 */
int MPI_Comm_free(MPI_Comm *comm) {
  assert(*comm != MPI_COMM_WORLD);
  logInfo("Free communicator " << *comm);
  const int rc = PMPI_Comm_free(comm);
  assert(rc == MPI_SUCCESS);
  return rc;
}

97
98
int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
              int tag, MPI_Comm comm) {
99
  //assert(comm == MPI_COMM_WORLD);
100
101
102
103
104
105
106
107
108
109
110
111
  int err;
#ifdef USE_MPI_OFFLOADING
  if(_is_server) {
    MPI_Comm mapped_comm = (comm==MPI_COMM_WORLD) ? _comm : comm;
    err = PMPI_Send(buf, count, datatype, dest, tag, mapped_comm);
  }
  else {
    err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm(comm));  
  }
#else
  err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm(comm));  
#endif
112
  logInfo("Send to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
113
114
115
116
117
  return err;
}

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
              MPI_Comm comm, MPI_Status *status) {
118
  //assert(comm == MPI_COMM_WORLD);
119
120
121
122
123
124
125
126
127
128
129
130
  int err;
#ifdef USE_MPI_OFFLOADING  
  if(_is_server) {
    MPI_Comm mapped_comm = (comm==MPI_COMM_WORLD) ? _comm : comm;
    err = PMPI_Recv(buf, count, datatype, source, tag, mapped_comm, status);
  }
  else {
    err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(comm), status);     
  }
#else
   err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(comm), status);
#endif
131
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
132
133
134
  return err;
}

135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                  void *recvbuf, int recvcount, MPI_Datatype recvtype,
                  MPI_Comm comm) {
  int err = PMPI_Allgather(sendbuf, sendcount, sendtype,
                           recvbuf, recvcount, recvtype,
                           getTeamComm(comm)); 
  return err;
}

int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                   void *recvbuf, int recvcount, MPI_Datatype recvtype,
                   MPI_Comm comm, MPI_Request * request) {
  int err = PMPI_Iallgather(sendbuf, sendcount, sendtype,
                            recvbuf, recvcount, recvtype,
                            getTeamComm(comm), request);
  return err;
}

153
154
int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
               int tag, MPI_Comm comm, MPI_Request *request) {
155
  //assert(comm == MPI_COMM_WORLD);
156
157
158
159
160
161
162
163
164
165
166
  int err;
#ifdef USE_MPI_OFFLOADING
  if(_is_server) {
    MPI_Comm mapped_comm = (comm==MPI_COMM_WORLD) ? _comm : comm;
    err = PMPI_Isend(buf, count, datatype, dest, tag, mapped_comm, request);
  }
  else {
    err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(comm), request);  
  }
#else
  err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(comm), request);
167
  logInfo("Isend to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
168
#endif 
169
170
171
172
173
  return err;
}

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
               MPI_Comm comm, MPI_Request *request) {
174
175
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Irecv(buf, count, datatype, source, tag, getTeamComm(comm), request);
176
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
177
178
179
180
181
  return err;
}

/**
 * Interposed MPI_Wait: logs before and after forwarding to PMPI_Wait.
 *
 * @return The error code of PMPI_Wait.
 */
int MPI_Wait(MPI_Request *request, MPI_Status *status) {
  logInfo("Wait initialised");
  const int rc = PMPI_Wait(request, status);
  logInfo("Wait completed");
  return rc;
}

187
188
int MPI_Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]) {
  logInfo("Waitall initialised with " << count << " requests");
189
  int err = PMPI_Waitall(count, array_of_requests, array_of_statuses);
190
191
192
193
  logInfo("Waitall completed with " << count << " requests");
  return err;
}

194
/**
 * Interposed MPI_Test: forwards to PMPI_Test and logs the outcome.
 *
 * The status object is only read when a request completed (*flag != 0) and
 * the caller did not pass MPI_STATUS_IGNORE. Dereferencing MPI_STATUS_IGNORE
 * is undefined behaviour, and the status contents are undefined when the
 * request has not completed.
 *
 * @return The error code of PMPI_Test.
 */
int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) {
  int err = PMPI_Test(request, flag, status);
  if (*flag && status != MPI_STATUS_IGNORE) {
    logInfo("Test completed (FLAG=" << *flag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
  } else {
    logInfo("Test completed (FLAG=" << *flag << ")");
  }
  return err;
}

int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
201
  //assert(comm == MPI_COMM_WORLD);
202
  logInfo("Probe initialised (SOURCE=" << source << ",TAG=" << tag << ")");
203
  int err = PMPI_Probe(source, tag, getTeamComm(comm), status);
204
  logInfo("Probe finished (SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
205
206
207
208
209
  return err;
}

int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
                MPI_Status *status) {
210
  //assert(comm == MPI_COMM_WORLD);
211
212
213
214
215
216
217
218
219
220
221
222
223
224
  int err;
#ifdef USE_MPI_OFFLOADING
  if(_is_server) {
    if(comm==MPI_COMM_WORLD)
      err = PMPI_Iprobe(source, tag, _comm, flag, status);
    else 
      err = PMPI_Iprobe(source, tag, comm, flag, status);      
  }
  else {
    err = PMPI_Iprobe(source, tag, getTeamComm(comm), flag, status);
  }
#else  
    err = PMPI_Iprobe(source, tag, getTeamComm(comm), flag, status);
#endif
225
  logInfo("Iprobe finished (FLAG=" << *flag << ",SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
226
227
228
  return err;
}

229
230
231
232
233
234
#ifdef USE_MPI_OFFLOADING

int MPI_Iprobe_offload(int source, int tag, MPI_Comm comm, int *flag, MPI_Status_Offload *status) {
  int ierr;
  assert(source==MPI_ANY_SOURCE);
#if COMMUNICATION_MODE==0
Philipp Samfaß's avatar
Philipp Samfaß committed
235
  ierr = smpi_iprobe_offload_p2p(source, tag, comm, flag, status);
236
#elif COMMUNICATION_MODE==1
Philipp Samfaß's avatar
Philipp Samfaß committed
237
238
  //ierr = smpi_iprobe_offload_rma(source, tag, comm, flag, status);
  ierr = smpi_iprobe_offload_p2p(source, tag, comm, flag, status);
239
240
241
242
243
#endif
  status->MPI_SOURCE = mapWorldToTeamRank(status->MPI_SOURCE);
  return ierr;
}

Philipp Samfass's avatar
Philipp Samfass committed
244
int MPI_Send_offload(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, int rail) {
245
246

#if COMMUNICATION_MODE==0
Philipp Samfaß's avatar
Philipp Samfaß committed
247
  return smpi_send_offload_p2p(buf, count, datatype, translateRank(comm, dest, _comm), tag, comm);
248
#elif COMMUNICATION_MODE==1
Philipp Samfaß's avatar
Philipp Samfaß committed
249
  return smpi_send_offload_rma(buf, count, datatype, translateRank(comm, dest, _comm), tag, comm);
250
251
252
253
254
#endif
}

int MPI_Recv_offload(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status_Offload *stat) {
#if COMMUNICATION_MODE==0
Philipp Samfaß's avatar
Philipp Samfaß committed
255
  return smpi_recv_offload_p2p(buf, count, datatype, translateRank(comm, source, _comm), tag, comm, stat);
256
#elif COMMUNICATION_MODE==1
Philipp Samfaß's avatar
Philipp Samfaß committed
257
  //return smpi_recv_offload_rma(buf, count, datatype, translateRank(comm, source, _comm), tag, comm, stat);
258
  return smpi_recv_offload_rma(buf, count, datatype, translateRank(comm, source, _comm), tag, comm, stat);
259
260
261
262
263
#endif
}

#endif

264
/**
 * Interposed MPI_Barrier.
 *
 * A barrier on MPI_COMM_WORLD synchronises the caller's team via
 * synchroniseRanksInTeam(). A barrier on any other communicator is forwarded
 * to PMPI_Barrier on getTeamComm(comm), as the other collectives here do.
 * Forwarding avoids synchronising the whole team when only a sub-communicator
 * participates, which could deadlock.
 *
 * @return The error code of synchroniseRanksInTeam() or PMPI_Barrier.
 */
int MPI_Barrier(MPI_Comm comm) {
  if (comm == MPI_COMM_WORLD) {
    return synchroniseRanksInTeam();
  }
  return PMPI_Barrier(getTeamComm(comm));
}

Ben Hazelwood's avatar
Ben Hazelwood committed
270
271
int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
               MPI_Comm comm ) {
272
273
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Bcast(buffer, count, datatype, root, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
274
275
276
277
278
  return err;
}

int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
                  MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
279
280
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
281
282
283
284
285
286
  return err;
}

int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                 void *recvbuf, int recvcount, MPI_Datatype recvtype,
                 MPI_Comm comm) {
287
288
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
289
290
291
292
293
294
295
  return err;
}

int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
                  const int *sdispls, MPI_Datatype sendtype, void *recvbuf,
                  const int *recvcounts, const int *rdispls, MPI_Datatype recvtype,
                  MPI_Comm comm) {
296
297
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
298
299
300
  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
301
double MPI_Wtime() {
Ben Hazelwood's avatar
Ben Hazelwood committed
302
  // If you mark on the timeline here expect negative time values (you've been warned)
303
  return PMPI_Wtime();
Ben Hazelwood's avatar
Ben Hazelwood committed
304
305
}

306
307
308
309
310
311
int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                int dest, int sendtag,
                void *recvbuf, int recvcount, MPI_Datatype recvtype,
                int source, int recvtag,
                MPI_Comm comm, MPI_Status *status) {
  if (comm == MPI_COMM_SELF) {
312
    if (sendcount == 0) {
313
      Timing::markTimeline(sendtag);
314
    } else {
315
      Timing::markTimeline(sendtag, sendbuf, sendcount, sendtype);
316
    }
317
  } else {
318
319
    //assert(comm == MPI_COMM_WORLD);
    MPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getTeamComm(comm),status);
320
321
322
323
  }
  return MPI_SUCCESS;
}

324
int MPI_Finalize() {
325
326
327
328
329
330
331
332
333
334
335
#ifdef USE_MPI_OFFLOADING
  if(!_is_server) {
#endif
    logInfo("Finalize");
    Timing::finaliseTiming();
    // Wait for all replicas before finalising
    PMPI_Barrier(TMPI_COMM_DUP);
    freeTeamComm();
    Timing::outputTiming();
#ifdef USE_MPI_OFFLOADING
  }
Philipp Samfaß's avatar
Philipp Samfaß committed
336
  smpi_finalize();
337
#endif
338
339
  return PMPI_Finalize();
}
Ben Hazelwood's avatar
Ben Hazelwood committed
340
341
342

/**
 * Interposed MPI_Abort: aborts through getTeamComm(comm).
 * Only MPI_COMM_WORLD is accepted (asserted in debug builds).
 *
 * @return The error code of PMPI_Abort (normally it does not return).
 */
int MPI_Abort(MPI_Comm comm, int errorcode) {
  assert(comm == MPI_COMM_WORLD);
  const MPI_Comm teamComm = getTeamComm(comm);
  return PMPI_Abort(teamComm, errorcode);
}