Wrapper.cpp 6.2 KB
Newer Older
1
#include "Wrapper.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
2
3

#include <cassert>
4
#include <unistd.h>
Ben Hazelwood's avatar
Ben Hazelwood committed
5

6
#include "Logging.h"
Ben Hazelwood's avatar
Ben Hazelwood committed
7
#include "Rank.h"
8
9
10
#include "Timing.h"

int MPI_Init(int *argc, char*** argv) {
11
  int err = PMPI_Init(argc, argv);
Ben Hazelwood's avatar
Ben Hazelwood committed
12
  initialiseTMPI();
13
  return err;
14
15
16
}

int MPI_Init_thread( int *argc, char ***argv, int required, int *provided ) {
17
  int err = PMPI_Init_thread(argc, argv, required, provided);
Ben Hazelwood's avatar
Ben Hazelwood committed
18
  initialiseTMPI();
19
  return err;
20
21
22
23
}

int MPI_Is_thread_main(int* flag) {
  // See header documentation
Ben Hazelwood's avatar
Ben Hazelwood committed
24
  *flag = getTeam();
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  return MPI_SUCCESS;
}

// Reports the caller's rank as given by getTeamRank(), rather than querying
// the MPI library. Only MPI_COMM_WORLD is accepted (asserted).
int MPI_Comm_rank(MPI_Comm comm, int *rank) {
  assert(comm == MPI_COMM_WORLD);
  const int teamRank = getTeamRank();
  *rank = teamRank;
  return MPI_SUCCESS;
}

// Reports the communicator size as given by getTeamSize(), rather than
// querying the MPI library. Only MPI_COMM_WORLD is accepted (asserted).
int MPI_Comm_size(MPI_Comm comm, int *size) {
  assert(comm == MPI_COMM_WORLD);
  const int teamSize = getTeamSize();
  *size = teamSize;
  return MPI_SUCCESS;
}

int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
              int tag, MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);
43
44
  int err = PMPI_Send(buf, count, datatype, dest, tag, getTeamComm());
  logInfo("Send to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
45
46
47
48
49
50
  return err;
}

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
              MPI_Comm comm, MPI_Status *status) {
  assert(comm == MPI_COMM_WORLD);
51
52
  int err = PMPI_Recv(buf, count, datatype, source, tag, getTeamComm(), status);
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
53
54
55
56
57
58
  return err;
}

int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest,
               int tag, MPI_Comm comm, MPI_Request *request) {
  assert(comm == MPI_COMM_WORLD);
59
60
  int err = PMPI_Isend(buf, count, datatype, dest, tag, getTeamComm(), request);
  logInfo("Isend to rank " << dest << "/" << mapTeamToWorldRank(dest) << " with tag " << tag);
61
62
63
64
65
66
  return err;
}

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
               MPI_Comm comm, MPI_Request *request) {
  assert(comm == MPI_COMM_WORLD);
67
68
  int err = PMPI_Irecv(buf, count, datatype, source, tag, getTeamComm(), request);
  logInfo("Receive from rank " << source << "/" << mapTeamToWorldRank(source) << " with tag " << tag);
69
70
71
72
73
  return err;
}

// Forwards to PMPI_Wait unchanged, logging before and after the wait.
int MPI_Wait(MPI_Request *request, MPI_Status *status) {
  logInfo("Wait initialised");
  int err = PMPI_Wait(request, status);
  logInfo("Wait completed");
  return err;
}

79
80
int MPI_Waitall(int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[]) {
  logInfo("Waitall initialised with " << count << " requests");
81
  int err = PMPI_Waitall(count, array_of_requests, array_of_statuses);
82
83
84
85
  logInfo("Waitall completed with " << count << " requests");
  return err;
}

86
int MPI_Test(MPI_Request *request, int *flag, MPI_Status *status) {
87
88
  int err = PMPI_Test(request, flag, status);
  logInfo("Test completed (FLAG=" << *flag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
89
90
91
92
93
  return err;
}

int MPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status *status) {
  assert(comm == MPI_COMM_WORLD);
94
95
96
  logInfo("Probe initialised (SOURCE=" << source << ",TAG=" << tag << ")");
  int err = PMPI_Probe(source, tag, getTeamComm(), status);
  logInfo("Probe finished (SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
97
98
99
100
101
102
  return err;
}

int MPI_Iprobe(int source, int tag, MPI_Comm comm, int *flag,
                MPI_Status *status) {
  assert(comm == MPI_COMM_WORLD);
103
104
  int err = PMPI_Iprobe(source, tag, getTeamComm(), flag, status);
  logInfo("Iprobe finished (FLAG=" << *flag << ",SOURCE=" << mapTeamToWorldRank(source) << ",TAG=" << tag << ",STATUS_SOURCE=" << status->MPI_SOURCE << ",STATUS_TAG=" << status->MPI_TAG << ")");
105
106
107
108
109
  return err;
}

// Barrier implemented by synchroniseRanksInTeam() rather than PMPI_Barrier;
// its result is returned as the MPI error code. Only MPI_COMM_WORLD is
// accepted (asserted).
int MPI_Barrier(MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);
  int err = synchroniseRanksInTeam();
  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
114
115
116
int MPI_Bcast( void *buffer, int count, MPI_Datatype datatype, int root,
               MPI_Comm comm ) {
  assert(comm == MPI_COMM_WORLD);
117
  int err = PMPI_Bcast(buffer, count, datatype, root, getTeamComm());
Ben Hazelwood's avatar
Ben Hazelwood committed
118
119
120
121
122
123
  return err;
}

int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
                  MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);
124
  int err = PMPI_Allreduce(sendbuf, recvbuf, count, datatype, op, getTeamComm());
Ben Hazelwood's avatar
Ben Hazelwood committed
125
126
127
128
129
130
131
  return err;
}

int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                 void *recvbuf, int recvcount, MPI_Datatype recvtype,
                 MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);
132
  int err = PMPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, getTeamComm());
Ben Hazelwood's avatar
Ben Hazelwood committed
133
134
135
136
137
138
139
140
  return err;
}

int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
                  const int *sdispls, MPI_Datatype sendtype, void *recvbuf,
                  const int *recvcounts, const int *rdispls, MPI_Datatype recvtype,
                  MPI_Comm comm) {
  assert(comm == MPI_COMM_WORLD);
141
  int err = PMPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, getTeamComm());
Ben Hazelwood's avatar
Ben Hazelwood committed
142
143
144
  return err;
}

Ben Hazelwood's avatar
Ben Hazelwood committed
145
double MPI_Wtime() {
Ben Hazelwood's avatar
Ben Hazelwood committed
146
  // If you mark on the timeline here expect negative time values (you've been warned)
147
  return PMPI_Wtime();
Ben Hazelwood's avatar
Ben Hazelwood committed
148
149
}

150
151
152
153
154
155
int MPI_Sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                int dest, int sendtag,
                void *recvbuf, int recvcount, MPI_Datatype recvtype,
                int source, int recvtag,
                MPI_Comm comm, MPI_Status *status) {
  if (comm == MPI_COMM_SELF) {
156
    if (sendcount == 0) {
157
      Timing::markTimeline(sendtag);
158
    } else {
159
      Timing::markTimeline(sendtag, sendbuf, sendcount, sendtype);
160
    }
161
162
  } else {
    assert(comm == MPI_COMM_WORLD);
Ben Hazelwood's avatar
Ben Hazelwood committed
163
    MPI_Sendrecv(sendbuf, sendcount, sendtype,dest,sendtag,recvbuf,recvcount,recvtype,source,recvtag, getTeamComm(),status);
164
165
166
167
  }
  return MPI_SUCCESS;
}

168
int MPI_Finalize() {
169
  logInfo("Finalize");
170
  Timing::finaliseTiming();
171
172
  // Wait for all replicas before finalising
  PMPI_Barrier(MPI_COMM_WORLD);
Ben Hazelwood's avatar
Ben Hazelwood committed
173
  freeTeamComm();
Ben Hazelwood's avatar
Ben Hazelwood committed
174
  Timing::outputTiming();
175
176
  return PMPI_Finalize();
}
Ben Hazelwood's avatar
Ben Hazelwood committed
177
178
179

// Forwards to PMPI_Abort on the team communicator (getTeamComm()) rather than
// on MPI_COMM_WORLD. Only MPI_COMM_WORLD is accepted (asserted).
// NOTE(review): whether this tears down only the team or the whole job is up
// to the MPI implementation — confirm that is the intended scope.
int MPI_Abort(MPI_Comm comm, int errorcode) {
  assert(comm == MPI_COMM_WORLD);
  int err = PMPI_Abort(getTeamComm(), errorcode);
  return err;
}