Skip to content

Commit 0e6ce0c

Browse files
committed
Rework ViewBroadcast implementation
- Use int[8] rather than another view to store outfacing layout - Implement 8 separate method for operator(), like Kokkos::View does - Improve docs for the constructor
1 parent 04e0b8f commit 0e6ce0c

File tree

1 file changed

+90
-14
lines changed

1 file changed

+90
-14
lines changed

src/kokkos/ekat_view_broadcast.hpp

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,22 @@ class ViewBroadcast
2222
using view_type = Unmanaged<ToView>;
2323
using reference_type = typename view_type::reference_type;
2424

25+
// Broadcast an input view to the type provided by the class template arg
26+
// - from_v: view to be broadcasted
27+
// - extents: list of extents of the outfacing view (of type ToView).
28+
// Must have rank=ToView::rank(). Must contain FromView::rank() entries
29+
// that are <=0, which signal that those dimensions' extents will be
30+
// retrieved from the input view.
2531
template<typename FromView>
2632
ViewBroadcast (const FromView& from_v,
2733
const std::vector<int>& extents)
2834
{
2935
constexpr int from_rank = FromView::rank();
3036
constexpr int to_rank = ToView::rank();
3137

38+
EKAT_REQUIRE_MSG (from_rank>=1,
39+
"[ViewBroadcast] Error! FromView rank must be at least 1.\n"
40+
" FromView::rank(): " << FromView::rank() << "\n");
3241
EKAT_REQUIRE_MSG (from_rank<=to_rank,
3342
"[ViewBroadcast] Error! FromView rank exceeds ToView rank.\n"
3443
" FromView::rank(): " << FromView::rank() << "\n"
@@ -41,46 +50,113 @@ class ViewBroadcast
4150
"[ViewBroadcast] Badly sized extents vector.\n");
4251

4352
// Init all dims as "ignored", then set up the ones actually picked
44-
typename ToView::traits::array_layout impl_layout,shape_layout;
53+
typename ToView::traits::array_layout impl_layout;
4554
int ifrom = 0;
4655
for (int i=0; i<to_rank; ++i) {
4756
if (extents[i]<=0) {
4857
EKAT_REQUIRE_MSG (ifrom<from_rank,
4958
"Error! Too many missing extents in input vector.\n");
5059
impl_layout.dimension[i] = from_v.extent_int(ifrom);
51-
shape_layout.dimension[i] = from_v.extent_int(ifrom);
52-
coeff[i] = 1;
60+
m_shape[i] = from_v.extent_int(ifrom);
61+
m_coeff[i] = 1;
5362
++ifrom;
5463
} else {
5564
impl_layout.dimension[i] = 1;
56-
shape_layout.dimension[i] = extents[i];
57-
coeff[i] = 0;
65+
m_shape[i] = extents[i];
66+
m_coeff[i] = 0;
5867
}
5968
}
6069
EKAT_REQUIRE_MSG (ifrom==from_rank,
6170
"Error! Too many positive extents in input vector.\n");
6271

6372
m_view_impl = view_type(from_v.data(),impl_layout);
64-
m_view_shape = view_type(from_v.data(),shape_layout);
6573
}
6674

6775
KOKKOS_INLINE_FUNCTION
68-
int extent(int i) { return m_view_shape.extent(i); }
76+
int extent(int i) {
77+
EKAT_KERNEL_ASSERT_MSG (i>=0 and i<=m_view_impl.rank(),
78+
"[ViewBroadcast::extent] Error! Index out of bounds.\n");
79+
return m_shape[i];
80+
}
81+
82+
// Rank 1
83+
template<typename IntType>
84+
KOKKOS_FORCEINLINE_FUNCTION
85+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
86+
operator()(IntType i0) const {
87+
return m_view_impl(i0*m_coeff[0]);
88+
}
89+
90+
// Rank 2
91+
template<typename IntType>
92+
KOKKOS_FORCEINLINE_FUNCTION
93+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
94+
operator()(IntType i0, IntType i1) const {
95+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1]);
96+
}
97+
98+
// Rank 3
99+
template<typename IntType>
100+
KOKKOS_FORCEINLINE_FUNCTION
101+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
102+
operator()(IntType i0, IntType i1, IntType i2) const {
103+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2]);
104+
}
105+
106+
// Rank 4
107+
template<typename IntType>
108+
KOKKOS_FORCEINLINE_FUNCTION
109+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
110+
operator()(IntType i0, IntType i1, IntType i2, IntType i3) const {
111+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3]);
112+
}
113+
114+
// Rank 5
115+
template<typename IntType>
116+
KOKKOS_FORCEINLINE_FUNCTION
117+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
118+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
119+
IntType i4) const {
120+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
121+
i4*m_coeff[4]);
122+
}
123+
124+
// Rank 6
125+
template<typename IntType>
126+
KOKKOS_FORCEINLINE_FUNCTION
127+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
128+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
129+
IntType i4, IntType i5) const {
130+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
131+
i4*m_coeff[4],i5*m_coeff[5]);
132+
}
133+
134+
// Rank 7
135+
template<typename IntType>
136+
KOKKOS_FORCEINLINE_FUNCTION
137+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
138+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
139+
IntType i4, IntType i5, IntType i6) const {
140+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
141+
i4*m_coeff[4],i5*m_coeff[5],i6*m_coeff[6]);
142+
}
69143

70-
template<typename... Is>
144+
// Rank 8
145+
template<typename IntType>
71146
KOKKOS_FORCEINLINE_FUNCTION
72-
reference_type operator()(Is... indices) const {
73-
int i=0;
74-
((indices *= coeff[i++]),...);
75-
return m_view_impl(indices...);
147+
std::enable_if_t<std::is_integral_v<IntType>,reference_type>
148+
operator()(IntType i0, IntType i1, IntType i2, IntType i3,
149+
IntType i4, IntType i5, IntType i6, IntType i7) const {
150+
return m_view_impl(i0*m_coeff[0],i1*m_coeff[1],i2*m_coeff[2],i3*m_coeff[3],
151+
i4*m_coeff[4],i5*m_coeff[5],i6*m_coeff[6],i7*m_coeff[7]);
76152
}
77153

78154
protected:
79155

80156
view_type m_view_impl;
81-
view_type m_view_shape;
82157

83-
int coeff[8];
158+
int m_shape[8];
159+
int m_coeff[8];
84160
};
85161

86162

0 commit comments

Comments
 (0)