Skip to content

Commit 7c82fc7

Browse files
authored
Merge pull request #383 from E3SM-Project/bartgol/view-broadcast
Add ability to broadcast views to higher-dimensional arrays
2 parents f02ab6b + 0e6ce0c commit 7c82fc7

File tree

3 files changed

+402
-0
lines changed

3 files changed

+402
-0
lines changed

src/kokkos/ekat_view_broadcast.hpp

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#ifndef EKAT_BROADCAST_VIEW_HPP
2+
#define EKAT_BROADCAST_VIEW_HPP
3+
4+
#include "ekat_kokkos_meta.hpp"
5+
#include "ekat_assert.hpp"
6+
7+
namespace ekat {
8+
9+
/*
10+
* Broadcast a view into a higher dimensional view
11+
*
12+
* This allows to replicate (e.g. extrude) a view as a higher
13+
* dimensional one, without having to manually copy entries.
14+
* E.g., one can replicate a column vector v into a matrix A
15+
* with each column equal to v, so that A(i,j) = v(j) for every i.
16+
*/
17+
18+
template<typename ToView>
19+
class ViewBroadcast
20+
{
21+
public:
22+
using view_type = Unmanaged<ToView>;
23+
using reference_type = typename view_type::reference_type;
24+
25+
// Broadcast an input view to the type provided by the class template arg
26+
// - from_v: view to be broadcasted
27+
// - extents: list of extents of the outfacing view (of type ToView).
28+
// Must have rank=ToView::rank(). Must contain FromView::rank() entries
29+
// that are <=0, which signal that those dimensions' extents will be
30+
// retrieved from the input view.
31+
template<typename FromView>
32+
ViewBroadcast (const FromView& from_v,
33+
const std::vector<int>& extents)
34+
{
35+
constexpr int from_rank = FromView::rank();
36+
constexpr int to_rank = ToView::rank();
37+
38+
EKAT_REQUIRE_MSG (from_rank>=1,
39+
"[ViewBroadcast] Error! FromView rank must be at least 1.\n"
40+
" FromView::rank(): " << FromView::rank() << "\n");
41+
EKAT_REQUIRE_MSG (from_rank<=to_rank,
42+
"[ViewBroadcast] Error! FromView rank exceeds ToView rank.\n"
43+
" FromView::rank(): " << FromView::rank() << "\n"
44+
" ToView::rank(): " << ToView::rank() << "\n");
45+
EKAT_REQUIRE_MSG (to_rank<=8,
46+
"[ViewBroadcast] Error! ToView rank exceeds maximum rank (8).\n"
47+
" ToView::rank(): " << ToView::rank() << "\n");
48+
49+
EKAT_REQUIRE_MSG (extents.size()==to_rank,
50+
"[ViewBroadcast] Badly sized extents vector.\n");
51+
52+
// Init all dims as "ignored", then set up the ones actually picked
53+
typename ToView::traits::array_layout impl_layout;
54+
int ifrom = 0;
55+
for (int i=0; i<to_rank; ++i) {
56+
if (extents[i]<=0) {
57+
EKAT_REQUIRE_MSG (ifrom<from_rank,
58+
"Error! Too many missing extents in input vector.\n");
59+
impl_layout.dimension[i] = from_v.extent_int(ifrom);
60+
m_shape[i] = from_v.extent_int(ifrom);
61+
m_coeff[i] = 1;
62+
++ifrom;
63+
} else {
64+
impl_layout.dimension[i] = 1;
65+
m_shape[i] = extents[i];
66+
m_coeff[i] = 0;
67+
}
68+
}
69+
EKAT_REQUIRE_MSG (ifrom==from_rank,
70+
"Error! Too many positive extents in input vector.\n");
71+
72+
m_view_impl = view_type(from_v.data(),impl_layout);
73+
}
74+
75+
KOKKOS_INLINE_FUNCTION
76+
int extent(int i) {
77+
EKAT_KERNEL_ASSERT_MSG (i>=0 and i<=m_view_impl.rank(),
78+
"[ViewBroadcast::extent] Error! Index out of bounds.\n");
79+
return m_shape[i];
80+
}
81+
82+
// Rank 1
83+
template<typename IntType>
84+
KOKKOS_FORCEINLINE_FUNCTION
85+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
86+
operator()(IntType i0) const {
87+
return m_view_impl(i0*m_coeff[0]);
88+
}
89+
90+
// Rank 2
91+
template<typename IntType>
92+
KOKKOS_FORCEINLINE_FUNCTION
93+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
94+
operator()(IntType i0, IntType i1) const {
95+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1]);
96+
}
97+
98+
// Rank 3
99+
template<typename IntType>
100+
KOKKOS_FORCEINLINE_FUNCTION
101+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
102+
operator()(IntType i0, IntType i1, IntType i2) const {
103+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2]);
104+
}
105+
106+
// Rank 4
107+
template<typename IntType>
108+
KOKKOS_FORCEINLINE_FUNCTION
109+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
110+
operator()(IntType i0, IntType i1, IntType i2, IntType i3) const {
111+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3]);
112+
}
113+
114+
// Rank 5
115+
template<typename IntType>
116+
KOKKOS_FORCEINLINE_FUNCTION
117+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
118+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
119+
IntType i4) const {
120+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
121+
i4*m_coeff[4]);
122+
}
123+
124+
// Rank 6
125+
template<typename IntType>
126+
KOKKOS_FORCEINLINE_FUNCTION
127+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
128+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
129+
IntType i4, IntType i5) const {
130+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
131+
i4*m_coeff[4],i5*m_coeff[5]);
132+
}
133+
134+
// Rank 7
135+
template<typename IntType>
136+
KOKKOS_FORCEINLINE_FUNCTION
137+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
138+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
139+
IntType i4, IntType i5, IntType i6) const {
140+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
141+
i4*m_coeff[4],i5*m_coeff[5],i6*m_coeff[6]);
142+
}
143+
144+
// Rank 8
145+
template<typename IntType>
146+
KOKKOS_FORCEINLINE_FUNCTION
147+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
148+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
149+
IntType i4, IntType i5, IntType i6, IntType i7) const {
150+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
151+
i4*m_coeff[4],i5*m_coeff[5],i6*m_coeff[6],i7*m_coeff[7]);
152+
}
153+
154+
protected:
155+
156+
view_type m_view_impl;
157+
158+
int m_shape[8];
159+
int m_coeff[8];
160+
};
161+
162+
163+
template<typename ToView,typename FromView>
164+
ViewBroadcast<ToView> broadcast (const FromView& from,
165+
const std::vector<int>& extents)
166+
{
167+
return ViewBroadcast<ToView>(from,extents);
168+
}
169+
170+
} // namespace ekat
171+
172+
#endif // EKAT_BROADCAST_VIEW_HPP

tests/kokkos/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ EkatCreateUnitTest(math_utils math_utils.cpp
1212
EkatCreateUnitTest(view_utils view_utils.cpp
1313
LIBS ekat::KokkosUtils)
1414

15+
# Test view broadcast
16+
EkatCreateUnitTest(view_broadcast view_broadcast.cpp
17+
LIBS ekat::KokkosUtils)
18+
1519
# Test subview utils
1620
EkatCreateUnitTest(subview_utils subview_utils.cpp
1721
LIBS ekat::KokkosUtils)

0 commit comments

Comments
 (0)