Wrapper.cpp 7.61 KB
Newer Older
1
#include "Wrapper.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
2
3

#include <cassert>
4
#include <unistd.h>
Ben Hazelwood's avatar
Ben Hazelwood committed
5

6
#include "Logging.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
7
#include "RankOperations.h"
8
#include "Timing.h"
9
#include "TMPIConstants.h"
10
11

int MPI_Init(int *argc, char*** argv) {
12
13
14
  int err = 0;

  err |= PMPI_Init(argc, argv);
15
16
17

  init_rank();

18
  return err;
19
20
21
}

int MPI_Init_thread( int *argc, char ***argv, int required, int *provided ) {
22
23
24
  int err = 0;

  err |= PMPI_Init_thread(argc, argv, required, provided);
25
26
27

  init_rank();

28
  return err;
29
30
31
32
33
34
35
36
37
38
39
40
}

int MPI_Is_thread_main(int* flag) {
  // Deliberately non-standard (see header documentation). Instead of a
  // boolean, the flag receives get_R_number() for this world rank, plus one.
  // NOTE(review): presumably get_R_number is the replica index, making the
  // value 1-based and non-zero — confirm against the header.
  const int replicaNumber = get_R_number(getWorldRank());
  *flag = replicaNumber + 1;
  return MPI_SUCCESS;
}

int MPI_Comm_rank(MPI_Comm comm, int *rank) {
  // Only MPI_COMM_WORLD is supported (checked by assert in debug builds).
  // The application sees its rank within its team, not the global rank.
  assert(comm == MPI_COMM_WORLD);
  *rank = getTeamRank();
  return MPI_SUCCESS;
}

int MPI_Comm_size(MPI_Comm comm, int *size) {
  // Only MPI_COMM_WORLD is supported (checked by assert in debug builds).
  // The application sees the size of its team, not of the whole job.
  assert(comm == MPI_COMM_WORLD);
  *size = getTeamSize();
  return MPI_SUCCESS;
}

int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
              int tag, MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;

61
  err |= PMPI_Send(buf, count, datatype, dest, tag, getReplicaCommunicator());
62

63
64
65
66
  logInfo(
      "Send to rank " <<
      dest <<
      "/" <<
67
      map_team_to_world(dest) <<
68
69
      " with tag " <<
      tag);
70
71
72
73
74
75
76
77
78
79

  return err;
}

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
              MPI_Comm comm, MPI_Status *status) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;

80
  err |= PMPI_Recv(buf, count, datatype, source, tag, getReplicaCommunicator(), status);
81
82

  remap_status(status);
83
84
85
86
87

  logInfo(
      "Receive from rank " <<
      source <<
      "/" <<
88
      map_team_to_world(source) <<
89
90
      " with tag " <<
      tag);
91
92
93
94
95
96
97
98
99
100

  return err;
}

int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
               int tag, MPI_Comm comm, MPI_Request *request) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;

101
  err |= PMPI_Isend(buf, count, datatype, dest, tag, getReplicaCommunicator(), request);
102
//  Timing::startNonBlocking(Timing::NonBlockingType::iSend, tag, request);
103

104
105
106
107
  logInfo(
      "Isend to rank " <<
      dest <<
      "/" <<
108
      map_team_to_world(dest) <<
109
110
      " with tag " <<
      tag);
111
112
113
114
115
116
117
118
119
120

  return err;
}

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
               MPI_Comm comm, MPI_Request *request) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;

121
  err |= PMPI_Irecv(buf, count, datatype, source, tag, getReplicaCommunicator(), request);
122

123
//  Timing::startNonBlocking(Timing::NonBlockingType::iRecv, tag, request);
124

125
126
127
128
  logInfo(
      "Receive from rank " <<
      source <<
      "/" <<
129
      map_team_to_world(source) <<
130
131
      " with tag " <<
      tag);
132
133
134
135
136
137
138
139
140
141

  return err;
}

/**
 * Intercepts MPI_Wait.
 *
 * Waits via PMPI_Wait and then, when the caller supplied a real status object,
 * passes it to remap_status() and logs its source/tag.
 *
 * @param status may be MPI_STATUS_IGNORE; it is then neither remapped nor
 *               dereferenced for logging.
 * @return the error code reported by PMPI_Wait.
 */
int MPI_Wait(MPI_Request *request, MPI_Status *status) {
  int err = 0;
  logInfo("Wait initialised");

  err |= PMPI_Wait(request, status);

//  Timing::endNonBlocking(request, status);

  if (status != MPI_STATUS_IGNORE) {
    remap_status(status);
    logInfo("Wait completed ("
        << "STATUS_SOURCE=" << status->MPI_SOURCE
        << ",STATUS_TAG=" << status->MPI_TAG
        << ")");
  } else {
    // No status to report; avoid dereferencing the MPI_STATUS_IGNORE sentinel.
    logInfo("Wait completed");
  }

  return err;
}

153
154
155
156
157
158
159
int MPI_Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]) {
  int err = 0;

  logInfo("Waitall initialised with " << count << " requests");

  err |= PMPI_Waitall(count, array_of_requests, array_of_statuses);

160
161
162
163
  if (array_of_statuses != MPI_STATUSES_IGNORE) {
    for (int i = 0; i < count; i++) {
      remap_status(&array_of_statuses[i]);
    }
164
165
166
167
168
169
170
  }

  logInfo("Waitall completed with " << count << " requests");

  return err;
}

171
172
173
174
175
176
/**
 * Intercepts MPI_Test.
 *
 * The status is only meaningful once the request has completed (`*flag` set).
 * Until then MPI leaves its contents undefined. The caller may also pass
 * MPI_STATUS_IGNORE. Remapping and logging of status fields therefore happen
 * only when a completed, real status is available.
 *
 * @return the error code reported by PMPI_Test.
 */
int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) {
  int err = 0;

  err |= PMPI_Test(request, flag, status);

  const bool haveStatus = *flag && status != MPI_STATUS_IGNORE;

  if (*flag) {
//    Timing::endNonBlocking(request, status);
    if (status != MPI_STATUS_IGNORE) {
      remap_status(status);
    }
  }

  if (haveStatus) {
    logInfo("Test completed ("
        << "FLAG=" << *flag
        << ",STATUS_SOURCE=" << status->MPI_SOURCE
        << ",STATUS_TAG=" << status->MPI_TAG
        << ")");
  } else {
    logInfo("Test completed ("
        << "FLAG=" << *flag
        << ")");
  }
  return err;
}

int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
  assert(comm == MPI_COMM_WORLD);
  int err = 0;

  logInfo(
      "Probe initialised (SOURCE="
      << source
      << ",TAG="
      << tag
      << ")");
199

200
  err |= PMPI_Probe(source, tag, getReplicaCommunicator(), status);
201
202
203
  remap_status(status);
  logInfo(
      "Probe finished ("
204
      << "SOURCE=" << map_team_to_world(source)
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
      << ",TAG=" << tag
      << ",STATUS_SOURCE=" << status->MPI_SOURCE
      << ",STATUS_TAG=" << status->MPI_TAG
      << ")");

  return err;

}

int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
                MPI_Status *status) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;

220
  err |= PMPI_Iprobe(source, tag, getReplicaCommunicator(), flag, status);
221
222
223
224
  remap_status(status);
  logInfo(
      "Iprobe finished ("
      << "FLAG=" << *flag
225
      << ",SOURCE=" << map_team_to_world(source)
226
227
228
229
230
231
232
233
234
235
236
237
238
      << ",TAG=" << tag
      << ",STATUS_SOURCE=" << status->MPI_SOURCE
      << ",STATUS_TAG=" << status->MPI_TAG
      << ")");

  return err;
}

int MPI_Barrier(MPI_Comm comm) {
  // Synchronise over the replica communicator rather than the real
  // MPI_COMM_WORLD the application asked for.
  assert(comm == MPI_COMM_WORLD);
  return PMPI_Barrier(getReplicaCommunicator());
}

Ben Hazelwood's avatar
Ben Hazelwood committed
244
245
246
247
248
int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
               MPI_Comm comm ) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;
249
  err |= PMPI_Bcast(buffer, count, datatype, root, getReplicaCommunicator());
Ben Hazelwood's avatar
Ben Hazelwood committed
250
251
252
253
254
255
256
257
258

  return err;
}

int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
                  MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;
259
  err |= PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getReplicaCommunicator());
Ben Hazelwood's avatar
Ben Hazelwood committed
260
261
262
263
264
265
266
267
268
269

  return err;
}

int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                 void *recvbuf, int recvcount, MPI_Datatype recvtype,
                 MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;
270
  err |= PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getReplicaCommunicator());
Ben Hazelwood's avatar
Ben Hazelwood committed
271
272
273
274
275
276
277
278
279
280
281

  return err;
}

int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
                  const int *sdispls, MPI_Datatype sendtype, void *recvbuf,
                  const int *recvcounts, const int *rdispls, MPI_Datatype recvtype,
                  MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);

  int err = 0;
282
  err |= PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getReplicaCommunicator());
Ben Hazelwood's avatar
Ben Hazelwood committed
283
284
285
286

  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
287
double MPI_Wtime() {
Ben Hazelwood's avatar
Ben Hazelwood committed
288
289
290
  double t = PMPI_Wtime();
  Timing::markTimeline(Timing::markType::Generic);
  return t;
Ben Hazelwood's avatar
Ben Hazelwood committed
291
292
}

293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                int dest, int sendtag,
                void *recvbuf, int recvcount, MPI_Datatype recvtype,
                int source, int recvtag,
                MPI_Comm comm, MPI_Status *status) {
  if (comm == MPI_COMM_SELF) {
    Timing::markTimeline(Timing::markType::Generic);
    std::cout << "HERE\n";
  } else {
    assert(comm == MPI_COMM_WORLD);
    //@TODO: remap status?
    MPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getReplicaCommunicator(),status);
  }
  return MPI_SUCCESS;
}

309
int MPI_Finalize() {
310
  logInfo("Finalize");
311

Ben Hazelwood's avatar
Ben Hazelwood committed
312
313
  Timing::markTimeline(Timing::markType::Finalize);

314
315
316
  // Wait for all replicas before finalising
  PMPI_Barrier(MPI_COMM_WORLD);

317
  freeReplicaCommunicator();
Ben Hazelwood's avatar
Ben Hazelwood committed
318
  Timing::outputTiming();
319
320
321

  return PMPI_Finalize();
}
Ben Hazelwood's avatar
Ben Hazelwood committed
322
323
324
325
326

int MPI_Abort(MPI_Comm comm, int errorcode) {
  // The abort is issued on the replica communicator, not the real world comm.
  // NOTE(review): whether this tears down only this replica or the whole job
  // depends on the MPI implementation — confirm this is the intended scope.
  assert(comm == MPI_COMM_WORLD);
  return PMPI_Abort(getReplicaCommunicator(), errorcode);
}