2.12.2021, 9:00 - 11:00: Due to updates GitLab may be unavailable for some minutes between 09:00 and 11:00.

Wrapper.cpp 11.6 KB
Newer Older
1
#include "Wrapper.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
2
3

#include <cassert>
4
#include <unistd.h>
Ben Hazelwood's avatar
Ben Hazelwood committed
5

6
#include "Logging.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
7
#include "Rank.h"
8
#include "Timing.h"
9
#include "CommStats.h"
10

11

12
#ifdef USE_SMARTMPI
13
14
15
16
#include "mpi_offloading.h"
#include "mpi_offloading_common.h"
#endif

17

18
int MPI_Init(int *argc, char*** argv) {
19
  int err = PMPI_Init(argc, argv);
20
#ifdef USE_SMARTMPI
Philipp Samfaß's avatar
Philipp Samfaß committed
21
  smpi_init( (*argv)[0]);
22
23
24
25
26
  if(!_is_server)
    initialiseTMPI(_comm);
#else
  initialiseTMPI(MPI_COMM_WORLD);
#endif
27
  return err;
28
29
30
}

int MPI_Init_thread( int *argc, char ***argv, int required, int *provided ) {
31
  int err = PMPI_Init_thread(argc, argv, required, provided);
32
#ifdef USE_SMARTMPI
Philipp Samfaß's avatar
Philipp Samfaß committed
33
  smpi_init( (*argv)[0]);
34
35
36
37
38
  if(!_is_server)
    initialiseTMPI(_comm);
#else
  initialiseTMPI(MPI_COMM_WORLD);
#endif
39
  return err;
40
41
42
43
}

int MPI_Is_thread_main(int* flag) {
  // See header documentation
Ben Hazelwood's avatar
Ben Hazelwood committed
44
  *flag = getTeam();
45
46
47
  return MPI_SUCCESS;
}

Philipp Samfaß's avatar
Philipp Samfaß committed
48
49
/**
 * Interposed MPI_Comm_split.
 *
 * Every communicator is now actually split. Previously the unbraced
 * `if (comm == MPI_COMM_WORLD)` meant other communicators were never split,
 * and with USE_SMARTMPI an uninitialised *newcomm was registered.
 *
 * Parent communicator used:
 *  - USE_SMARTMPI: _comm in place of MPI_COMM_WORLD, otherwise comm;
 *  - no USE_SMARTMPI: comm as given.
 *    NOTE(review): for MPI_COMM_WORLD this splits the real world rather than
 *    getTeamComm(comm), unlike MPI_Comm_dup. This preserves the original
 *    behaviour; confirm it is intended.
 * With USE_SMARTMPI the new communicator is registered only on success.
 *
 * @return the error code reported by PMPI_Comm_split
 */
int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm *newcomm) {
#ifdef USE_SMARTMPI
  MPI_Comm parent = (comm == MPI_COMM_WORLD) ? _comm : comm;
#else
  MPI_Comm parent = comm;
#endif
  const int err = PMPI_Comm_split(parent, color, key, newcomm);
#ifdef USE_SMARTMPI
  if (err == MPI_SUCCESS)
    register_new_comm(*newcomm);
#endif
  return err;
}

61
/**
 * Interposed MPI_Comm_rank.
 *
 * - Any communicator other than MPI_COMM_WORLD goes straight to PMPI_Comm_rank.
 * - MPI_COMM_WORLD on a USE_SMARTMPI server: the rank within _comm.
 * - MPI_COMM_WORLD otherwise: the rank within the caller's team (getTeamRank()).
 *
 * Always returns MPI_SUCCESS; the PMPI return code is not propagated.
 */
int MPI_Comm_rank(MPI_Comm comm, int *rank) {
  // todo: assert that a team comm is used
  if (comm != MPI_COMM_WORLD) {
    PMPI_Comm_rank(comm, rank);
    return MPI_SUCCESS;
  }
#ifdef USE_SMARTMPI
  if (_is_server) {
    PMPI_Comm_rank(_comm, rank);
    return MPI_SUCCESS;
  }
#endif
  *rank = getTeamRank();
  return MPI_SUCCESS;
}

/**
 * Interposed MPI_Comm_size, mirroring MPI_Comm_rank.
 *
 * - Any communicator other than MPI_COMM_WORLD goes to PMPI_Comm_size.
 *   Previously the team size was reported for every communicator, which is
 *   wrong for split or duplicated communicators.
 * - MPI_COMM_WORLD on a USE_SMARTMPI server: the size of _comm, matching the
 *   rank that MPI_Comm_rank reports there.
 * - MPI_COMM_WORLD otherwise: the team size (getTeamSize()).
 *
 * @return MPI_SUCCESS for the team case, otherwise PMPI_Comm_size's code
 */
int MPI_Comm_size(MPI_Comm comm, int *size) {
  if (comm != MPI_COMM_WORLD)
    return PMPI_Comm_size(comm, size);
#ifdef USE_SMARTMPI
  if (_is_server)
    return PMPI_Comm_size(_comm, size);
#endif
  *size = getTeamSize();
  return MPI_SUCCESS;
}

86
87
88
89
/**
 * Interposed MPI_Comm_dup.
 *
 * Duplicates getTeamComm(comm) rather than comm itself and logs the new
 * handle. Under USE_SMARTMPI the duplicate is also passed to
 * register_new_comm().
 *
 * @return the error code reported by PMPI_Comm_dup (also asserted to be MPI_SUCCESS)
 */
int MPI_Comm_dup(MPI_Comm comm, MPI_Comm *newcomm) {
  const MPI_Comm source = getTeamComm(comm);
  const int rc = PMPI_Comm_dup(source, newcomm);
  logInfo("Created communicator " << *newcomm);
  assert(rc == MPI_SUCCESS);
#ifdef USE_SMARTMPI
  register_new_comm(*newcomm);
#endif
  return rc;
}

/**
 * Interposed MPI_Comm_free: logs and releases a communicator.
 *
 * Asserts that the caller is not freeing MPI_COMM_WORLD and that
 * PMPI_Comm_free succeeds.
 *
 * @return the error code reported by PMPI_Comm_free
 */
int MPI_Comm_free(MPI_Comm *comm) {
  // Freeing the world communicator is a programming error.
  assert(*comm != MPI_COMM_WORLD);
  logInfo("Free communicator " << *comm);
  const int rc = PMPI_Comm_free(comm);
  assert(rc == MPI_SUCCESS);
  return rc;
}

104
105
int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
              int tag, MPI_Comm comm) {
106
#ifdef COMM_STATS
107
108
  CommunicationStatistics::trackSend(datatype,  count);
#endif
109
  //assert(comm == MPI_COMM_WORLD);
110
  int err;
111
#ifdef USE_SMARTMPI
112
113
114
115
116
117
118
119
120
121
  if(_is_server) {
    MPI_Comm mapped_comm = (comm==MPI_COMM_WORLD) ? _comm : comm;
    err = PMPI_Send(buf, count, datatype, dest, tag, mapped_comm);
  }
  else {
    err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm(comm));  
  }
#else
  err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm(comm));  
#endif
122
  logInfo("Send to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
123
124
125
126
127
  return err;
}

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
              MPI_Comm comm, MPI_Status *status) {
128
  //assert(comm == MPI_COMM_WORLD);
129
  int err;
130
#ifdef USE_SMARTMPI 
131
132
133
134
135
136
137
138
  if(_is_server) {
    MPI_Comm mapped_comm = (comm==MPI_COMM_WORLD) ? _comm : comm;
    err = PMPI_Recv(buf, count, datatype, source, tag, mapped_comm, status);
  }
  else {
    err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(comm), status);     
  }
#else
139
#ifdef COMM_STATS
140
141
  CommunicationStatistics::trackReceive(datatype,  count);
#endif
142
143
   err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(comm), status);
#endif
144
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
145
146
147
  return err;
}

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/**
 * Interposed MPI_Allgather: forwards every argument unchanged to
 * PMPI_Allgather, except that comm is replaced by getTeamComm(comm).
 *
 * @return the error code reported by PMPI_Allgather
 */
int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                  void *recvbuf, int recvcount, MPI_Datatype recvtype,
                  MPI_Comm comm) {
  return PMPI_Allgather(sendbuf, sendcount, sendtype,
                        recvbuf, recvcount, recvtype,
                        getTeamComm(comm));
}

/**
 * Interposed non-blocking allgather: forwards to PMPI_Iallgather with comm
 * replaced by getTeamComm(comm).
 *
 * @return the error code reported by PMPI_Iallgather
 */
int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                   void *recvbuf, int recvcount, MPI_Datatype recvtype,
                   MPI_Comm comm, MPI_Request * request) {
  return PMPI_Iallgather(sendbuf, sendcount, sendtype,
                         recvbuf, recvcount, recvtype,
                         getTeamComm(comm), request);
}

166
167
int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
               int tag, MPI_Comm comm, MPI_Request *request) {
168
  //assert(comm == MPI_COMM_WORLD);
169
  int err;
170
#ifdef USE_SMARTMPI
171
172
173
174
175
176
177
178
  if(_is_server) {
    MPI_Comm mapped_comm = (comm==MPI_COMM_WORLD) ? _comm : comm;
    err = PMPI_Isend(buf, count, datatype, dest, tag, mapped_comm, request);
  }
  else {
    err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(comm), request);  
  }
#else
179
#ifdef COMM_STATS
180
181
  CommunicationStatistics::trackSend(datatype,  count);
#endif
182
  err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(comm), request);
183
  logInfo("Isend to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
184
#endif 
185
186
187
188
189
  return err;
}

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
               MPI_Comm comm, MPI_Request *request) {
190
  //assert(comm == MPI_COMM_WORLD);
191
#ifdef COMM_STATS
192
193
  CommunicationStatistics::trackReceive(datatype,  count);
#endif
194
  int err = PMPI_Irecv(buf, count, datatype, source, tag, getTeamComm(comm), request);
195
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
196
197
198
199
200
  return err;
}

/**
 * Interposed MPI_Wait: forwards to PMPI_Wait, with a log line before and
 * after the blocking call.
 *
 * @return the error code reported by PMPI_Wait
 */
int MPI_Wait(MPI_Request *request, MPI_Status *status) {
  logInfo("Wait initialised");
  const int rc = PMPI_Wait(request, status);
  logInfo("Wait completed");
  return rc;
}

206
207
int MPI_Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]) {
  logInfo("Waitall initialised with " << count << " requests");
208
  int err = PMPI_Waitall(count, array_of_requests, array_of_statuses);
209
210
211
212
  logInfo("Waitall completed with " << count << " requests");
  return err;
}

213
/**
 * Interposed MPI_Test: forwards to PMPI_Test and logs the outcome.
 *
 * Status fields are logged only when the request completed (*flag != 0) and
 * the caller supplied a real status. Callers may pass MPI_STATUS_IGNORE,
 * which must not be dereferenced, and the status contents are undefined
 * while the request is still pending.
 *
 * @return the error code reported by PMPI_Test
 */
int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) {
  int err = PMPI_Test(request, flag, status);
  if (*flag && status != MPI_STATUS_IGNORE) {
    logInfo("Test completed (FLAG=" << *flag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
  } else {
    logInfo("Test completed (FLAG=" << *flag << ")");
  }
  return err;
}

int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
220
  //assert(comm == MPI_COMM_WORLD);
221
  logInfo("Probe initialised (SOURCE=" << source << ",TAG=" << tag << ")");
222
  int err = PMPI_Probe(source, tag, getTeamComm(comm), status);
223
  logInfo("Probe finished (SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
224
225
226
227
228
  return err;
}

int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
                MPI_Status *status) {
229
  //assert(comm == MPI_COMM_WORLD);
230
  int err;
231
#ifdef USE_SMARTMPI
232
233
234
235
236
237
238
239
240
241
242
243
  if(_is_server) {
    if(comm==MPI_COMM_WORLD)
      err = PMPI_Iprobe(source, tag, _comm, flag, status);
    else 
      err = PMPI_Iprobe(source, tag, comm, flag, status);      
  }
  else {
    err = PMPI_Iprobe(source, tag, getTeamComm(comm), flag, status);
  }
#else  
    err = PMPI_Iprobe(source, tag, getTeamComm(comm), flag, status);
#endif
244
  logInfo("Iprobe finished (FLAG=" << *flag << ",SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
245
246
247
  return err;
}

248
#ifdef USE_SMARTMPI
249
250
251
252
253

int MPI_Iprobe_offload(int source, int tag, MPI_Comm comm, int *flag, MPI_Status_Offload *status) {
  int ierr;
  assert(source==MPI_ANY_SOURCE);
#if COMMUNICATION_MODE==0
Philipp Samfaß's avatar
Philipp Samfaß committed
254
  ierr = smpi_iprobe_offload_p2p(source, tag, comm, flag, status);
Philipp Samfaß's avatar
Philipp Samfaß committed
255
#elif COMMUNICATION_MODE==2
Philipp Samfaß's avatar
Philipp Samfaß committed
256
257
  //ierr = smpi_iprobe_offload_rma(source, tag, comm, flag, status);
  ierr = smpi_iprobe_offload_p2p(source, tag, comm, flag, status);
258
#endif
Philipp Samfass's avatar
Philipp Samfass committed
259
260
  if(*flag)
    status->MPI_SOURCE = translateRank(MPI_COMM_WORLD, status->MPI_SOURCE, comm);
261
262
263
  return ierr;
}

264
int MPI_Send_offload(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, int rail) {
265
266

#if COMMUNICATION_MODE==0
267
  return smpi_send_offload_p2p(buf, count, datatype, translateRank(comm, dest, _comm), tag, comm, rail);
268
269
#elif COMMUNICATION_MODE==2
  return smpi_send_offload_gpi(buf, count, datatype, translateRank(comm, dest, _comm), tag, comm, rail);
270
271
272
#endif
}

273
int MPI_Recv_offload(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status_Offload *stat, int rail) {
274
#if COMMUNICATION_MODE==0
275
276
  int ierr = smpi_recv_offload_p2p(buf, count, datatype, translateRank(comm, source, _comm), tag, comm, stat, rail);
  return ierr;
277
#elif COMMUNICATION_MODE==2
Philipp Samfaß's avatar
Philipp Samfaß committed
278
  //return smpi_recv_offload_rma(buf, count, datatype, translateRank(comm, source, _comm), tag, comm, stat);
279
  return smpi_recv_offload_gpi(buf, count, datatype, translateRank(comm, source, _comm), tag, comm, stat, rail);
280
281
282
283
284
#endif
}

#endif

285
/**
 * Interposed MPI_Barrier.
 *
 * - MPI_COMM_WORLD: team-wide synchronisation through synchroniseRanksInTeam().
 *   This is the original behaviour.
 * - Any other communicator (e.g. from MPI_Comm_split/MPI_Comm_dup): a plain
 *   PMPI_Barrier on that communicator. This mirrors MPI_Comm_rank's
 *   pass-through for non-WORLD communicators. Synchronising the whole team
 *   there could deadlock when only a subset of ranks enters the barrier.
 *
 * @return the error code of the synchronisation that was performed
 */
int MPI_Barrier(MPI_Comm comm) {
  if (comm != MPI_COMM_WORLD)
    return PMPI_Barrier(comm);
  int err = synchroniseRanksInTeam();
  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
291
292
int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
               MPI_Comm comm ) {
293
294
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Bcast(buffer, count, datatype, root, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
295
296
297
298
299
  return err;
}

int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
                  MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
300
301
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
302
303
304
305
306
307
  return err;
}

int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                 void *recvbuf, int recvcount, MPI_Datatype recvtype,
                 MPI_Comm comm) {
308
309
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
310
311
312
313
314
315
316
  return err;
}

int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
                  const int *sdispls, MPI_Datatype sendtype, void *recvbuf,
                  const int *recvcounts, const int *rdispls, MPI_Datatype recvtype,
                  MPI_Comm comm) {
317
318
  //assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getTeamComm(comm));
Ben Hazelwood's avatar
Ben Hazelwood committed
319
320
321
  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
322
double MPI_Wtime() {
Ben Hazelwood's avatar
Ben Hazelwood committed
323
  // If you mark on the timeline here expect negative time values (you've been warned)
324
  return PMPI_Wtime();
Ben Hazelwood's avatar
Ben Hazelwood committed
325
326
}

327
328
329
330
331
332
int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                int dest, int sendtag,
                void *recvbuf, int recvcount, MPI_Datatype recvtype,
                int source, int recvtag,
                MPI_Comm comm, MPI_Status *status) {
  if (comm == MPI_COMM_SELF) {
333
    if (sendcount == 0) {
334
      Timing::markTimeline(sendtag);
335
    } else {
336
      Timing::markTimeline(sendtag, sendbuf, sendcount, sendtype);
337
    }
338
  } else {
339
    //assert(comm == MPI_COMM_WORLD);
340
    PMPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getTeamComm(comm),status);
341
342
343
344
  }
  return MPI_SUCCESS;
}

345
int MPI_Finalize() {
346
#ifdef USE_SMARTMPI
347
348
349
350
351
352
353
354
  if(!_is_server) {
#endif
    logInfo("Finalize");
    Timing::finaliseTiming();
    // Wait for all replicas before finalising
    PMPI_Barrier(TMPI_COMM_DUP);
    freeTeamComm();
    Timing::outputTiming();
355
#ifdef USE_SMARTMPI
356
  }
357
358
  smpi_finalize();
#endif
359
#ifdef COMM_STATS
360
361
  CommunicationStatistics::outputCommunicationStatistics();
#endif
362
363
#ifdef DirtyCleanUp
  return MPI_SUCCESS;
364
#endif
Philipp Samfaß's avatar
Philipp Samfaß committed
365
  freeTeamInterComm();
366
367
  return PMPI_Finalize();
}
Ben Hazelwood's avatar
Ben Hazelwood committed
368
369
370

/**
 * Interposed MPI_Abort: only MPI_COMM_WORLD is accepted (asserted). The
 * abort is issued on getTeamComm(comm).
 *
 * @return the error code reported by PMPI_Abort
 */
int MPI_Abort(MPI_Comm comm, int errorcode) {
  assert(comm == MPI_COMM_WORLD);
  const MPI_Comm target = getTeamComm(comm);
  return PMPI_Abort(target, errorcode);
}