Skip to content

Commit 9346539

Browse files
committed
feat: separate halo exchange logic into class
[skip-ci]
1 parent d8fd949 commit 9346539

File tree

2 files changed

+256
-171
lines changed

2 files changed

+256
-171
lines changed

core/src/include/Halo.hpp

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/*!
2+
* @file Halo.hpp
3+
*
4+
* @date 28 Jan 2025
5+
* @author Tom Meltzer <tdm39@cam.ac.uk>
6+
*/
7+
8+
#ifndef HALO_HPP
9+
#define HALO_HPP
10+
11+
#include <memory>
12+
#include <numeric>
13+
#include <vector>
14+
15+
#include "Slice.hpp"
16+
#include "include/ModelArray.hpp"
17+
#include "include/ModelArraySlice.hpp"
18+
#include "include/ModelMetadata.hpp"
19+
#include "include/dgVector.hpp"
20+
#include "include/indexer.hpp"
21+
#include "mpi.h"
22+
23+
#ifndef DGCOMP
24+
#define DGCOMP 3
25+
#endif
26+
27+
#ifndef DGSTRESSCOMP
28+
#define DGSTRESSCOMP 8
29+
#endif
30+
31+
#ifndef CGDEGREE
32+
#define CGDEGREE 2
33+
#endif
34+
35+
namespace Nextsim {
36+
37+
/*!
38+
* @brief A class to facilitate halo exchange between MPI ranks
39+
*
40+
* @details
41+
*/
42+
class Halo {
43+
public:
44+
/*!
45+
* @brief Constructs a halo object
46+
*/
47+
Halo(size_t localExtentX, size_t localExtentY, ModelMetadata& metadata)
48+
: m_localExtentX(localExtentX)
49+
, m_localExtentY(localExtentY)
50+
, m_metadata(std::make_unique<ModelMetadata>(metadata))
51+
, m_comm(metadata.mpiComm)
52+
, m_haloDims({ localExtentX + 2 * m_halo_width, localExtentY + 2 * m_halo_width })
53+
{
54+
m_perimeterLength = 2 * localExtentX + 2 * localExtentY;
55+
send.resize(m_perimeterLength, 0.0);
56+
recv.resize(m_perimeterLength, 0.0);
57+
m_edgeLengths
58+
= { localExtentX, localExtentY, localExtentX, localExtentY }; // order is Bottom
59+
// Right Top Left
60+
}
61+
62+
private:
63+
using Slice = ArraySlicer::Slice;
64+
using SliceIter = ArraySlicer::SliceIter;
65+
using Edge = ModelMetadata::Edge;
66+
using VBounds = ArraySlicer::Slice::VBounds;
67+
68+
const size_t m_halo_width = 1; // how many cells wide is the halo region
69+
const typedef ArraySlicer::SliceIter::MultiDim MultiDim;
70+
const MultiDim m_haloDims;
71+
size_t m_localExtentX; // local extent in x-direction
72+
size_t m_localExtentY; // local extent in y-direction
73+
size_t m_perimeterLength; // length of perimeter of domain
74+
std::unique_ptr<ModelMetadata> m_metadata; // pointer to metadata
75+
std::array<size_t, Edge::N_EDGE> m_edgeLengths; // array containing length of each edge
76+
std::array<Edge, Edge::N_EDGE> edges = ModelMetadata::edges; // array of edge enums
77+
std::map<Edge, Slice> m_slices = {
78+
{ Edge::LEFT, VBounds({ { 0 }, {} }) },
79+
{ Edge::RIGHT, VBounds({ { -1 }, {} }) },
80+
{ Edge::TOP, VBounds({ {}, { -1 } }) },
81+
{ Edge::BOTTOM, VBounds({ {}, { 0 } }) },
82+
};
83+
84+
MPI_Win m_win; // RMA memory window object (used for sharing send buffers between ranks)
85+
MPI_Comm m_comm; // RMA memory window object (used for sharing send buffers between ranks)
86+
87+
/*!
88+
* @brief Open memory window to exchange send buffer between MPI ranks.
89+
*
90+
* @ details this is not intended to be used manually. It should only be called as part of the
91+
* update method.
92+
*/
93+
void openMemoryWindow()
94+
{
95+
// create a RMA memory window which all ranks will be able to access
96+
MPI_Win_create(&send[0], m_perimeterLength * sizeof(double), sizeof(double), MPI_INFO_NULL,
97+
m_comm, &m_win);
98+
// remove fence and check that no proceding RMA calls have been made
99+
MPI_Win_fence(MPI_MODE_NOPRECEDE, m_win);
100+
}
101+
102+
/*!
103+
* @brief Initialise memory window to exchange send buffer between MPI ranks.
104+
*
105+
* @ details this is not intended to be used manually. It should only be called as part of the
106+
* update method.
107+
*/
108+
void closeMemoryWindow()
109+
{
110+
// enable fence i.e., disable future RMA calls until we re-open memory window
111+
MPI_Win_fence(MPI_MODE_NOSUCCEED, m_win);
112+
// free window object
113+
MPI_Win_free(&m_win);
114+
}
115+
116+
public:
117+
std::vector<double> send; // buffer to store halo region that will be read by other ranks
118+
std::vector<double> recv; // buffer to store halo region which is read from other ranks
119+
120+
/*!
121+
* @brief Populate send buffer with halo data of the specified ModelArray
122+
*
123+
* @params ma ModelArray which we intend to update across MPI ranks
124+
*/
125+
void populateSendBuffer(ModelArray& ma)
126+
{
127+
for (auto edge : ModelMetadata::edges) {
128+
size_t offset = std::accumulate(m_edgeLengths.begin(), m_edgeLengths.begin() + edge, 0);
129+
ma[m_slices.at(edge)].copyToBuffer(send, offset);
130+
}
131+
}
132+
133+
/*!
134+
* @brief Populate recv buffer with halo data from other ranks send buffers via the memory
135+
* window
136+
*/
137+
void populateRecvBuffer()
138+
{
139+
140+
// open memory window to send buffer on other ranks
141+
openMemoryWindow();
142+
143+
// get non-periodic neighbours and populate recv buffer (if the exist)
144+
for (auto edge : ModelMetadata::edges) {
145+
auto numNeighbours = m_metadata->neighbourRanks[edge].size();
146+
if (numNeighbours) {
147+
// get data for each neighbour that exists along a given edge
148+
for (size_t i = 0; i < numNeighbours; ++i) {
149+
int fromRank = m_metadata->neighbourRanks[edge][i];
150+
size_t count = m_metadata->neighbourExtents[edge][i];
151+
size_t disp = m_metadata->neighbourHaloSend[edge][i];
152+
size_t recvOffset = m_metadata->neighbourHaloRecv[edge][i];
153+
MPI_Get(&recv[recvOffset], count, MPI_DOUBLE, fromRank, disp, count, MPI_DOUBLE,
154+
m_win);
155+
}
156+
}
157+
}
158+
159+
// get periodic neighbours and populate recv buffer (if they exist)
160+
for (auto edge : ModelMetadata::edges) {
161+
auto numNeighbours = m_metadata->neighbourRanksPeriodic[edge].size();
162+
if (numNeighbours) {
163+
// get data for each neighbour that exists along a given edge
164+
for (size_t i = 0; i < numNeighbours; ++i) {
165+
int fromRank = m_metadata->neighbourRanksPeriodic[edge][i];
166+
size_t count = m_metadata->neighbourExtentsPeriodic[edge][i];
167+
size_t disp = m_metadata->neighbourHaloSendPeriodic[edge][i];
168+
size_t recvOffset = m_metadata->neighbourHaloRecvPeriodic[edge][i];
169+
MPI_Get(&recv[recvOffset], count, MPI_DOUBLE, fromRank, disp, count, MPI_DOUBLE,
170+
m_win);
171+
}
172+
}
173+
}
174+
175+
// close memory window (essentially make sure all communications are done before moving on)
176+
closeMemoryWindow();
177+
}
178+
179+
/*!
180+
* @brief Update a DGVector with data from the recv buffer
181+
*
182+
* @params dgvec DGVector which we intend to update across MPI ranks based on halo cells
183+
*/
184+
void updateDGVec(DGVector<DGCOMP>& dgvec)
185+
{
186+
for (auto edge : edges) {
187+
188+
SliceIter sIter = SliceIter(m_slices.at(edge), m_haloDims);
189+
std::vector<size_t> edgeIndices;
190+
191+
// populate edgeIndices with the indices along a given edge of the domain
192+
while (!sIter.isEnd()) {
193+
const size_t start = sIter.index();
194+
const size_t step = sIter.step(0);
195+
const size_t n = sIter.nElements(0);
196+
for (int i = 0; i < n; ++i) {
197+
auto idx = start + i * step;
198+
edgeIndices.push_back(idx);
199+
}
200+
sIter.incrementDim(1);
201+
}
202+
203+
// calculate offset index for the recv buffer based on current edge
204+
const size_t offset
205+
= std::accumulate(m_edgeLengths.begin(), m_edgeLengths.begin() + edge, 0);
206+
207+
// copy the halo region from recv buffer into the DGVector
208+
for (size_t i = 0; i < edgeIndices.size() - 2; ++i) {
209+
// note that the start index is offset by 1 and the loop limit is size() - 2 because
210+
// the edge of each domain is 2 less than the length of the expanded halo region
211+
// (see diagram below - the empty cells are skipped by going from i+1 to size()-2)
212+
// ┌─┬─┬─┬─┐
213+
// │ │x│x│ │
214+
// ├─┼─┼─┼─┤
215+
// │x│o│o│x│
216+
// ├─┼─┼─┼─┤
217+
// │x│o│o│x│ o = original data
218+
// ├─┼─┼─┼─┤ x = mpi halo data (from recv)
219+
// │ │x│x│ │ (empty) = unused data in DGVector
220+
// └─┴─┴─┴─┘
221+
dgvec(edgeIndices[i + 1], 0) = recv[offset + i];
222+
}
223+
}
224+
225+
recv.clear();
226+
send.clear();
227+
}
228+
};
229+
} // end of nextsim namespace
230+
231+
#endif /* HALO_HPP */

0 commit comments

Comments
 (0)