@@ -22,13 +22,22 @@ class ViewBroadcast
22
22
using view_type = Unmanaged<ToView>;
23
23
using reference_type = typename view_type::reference_type;
24
24
25
+ // Broadcast an input view to the type provided by the class template arg
26
+ // - from_v: view to be broadcasted
27
+ // - extents: list of extents of the outfacing view (of type ToView).
28
+ // Must have rank=ToView::rank(). Must contain FromView::rank() entries
29
+ // that are <=0, which signal that those dimensions' extents will be
30
+ // retrieved from the input view.
25
31
template <typename FromView>
26
32
ViewBroadcast (const FromView& from_v,
27
33
const std::vector<int >& extents)
28
34
{
29
35
constexpr int from_rank = FromView::rank ();
30
36
constexpr int to_rank = ToView::rank ();
31
37
38
+ EKAT_REQUIRE_MSG (from_rank>=1 ,
39
+ " [ViewBroadcast] Error! FromView rank must be at least 1.\n "
40
+ " FromView::rank(): " << FromView::rank () << " \n " );
32
41
EKAT_REQUIRE_MSG (from_rank<=to_rank,
33
42
" [ViewBroadcast] Error! FromView rank exceeds ToView rank.\n "
34
43
" FromView::rank(): " << FromView::rank () << " \n "
@@ -41,46 +50,113 @@ class ViewBroadcast
41
50
" [ViewBroadcast] Badly sized extents vector.\n " );
42
51
43
52
// Init all dims as "ignored", then set up the ones actually picked
44
- typename ToView::traits::array_layout impl_layout,shape_layout ;
53
+ typename ToView::traits::array_layout impl_layout;
45
54
int ifrom = 0 ;
46
55
for (int i=0 ; i<to_rank; ++i) {
47
56
if (extents[i]<=0 ) {
48
57
EKAT_REQUIRE_MSG (ifrom<from_rank,
49
58
" Error! Too many missing extents in input vector.\n " );
50
59
impl_layout.dimension [i] = from_v.extent_int (ifrom);
51
- shape_layout. dimension [i] = from_v.extent_int (ifrom);
52
- coeff [i] = 1 ;
60
+ m_shape [i] = from_v.extent_int (ifrom);
61
+ m_coeff [i] = 1 ;
53
62
++ifrom;
54
63
} else {
55
64
impl_layout.dimension [i] = 1 ;
56
- shape_layout. dimension [i] = extents[i];
57
- coeff [i] = 0 ;
65
+ m_shape [i] = extents[i];
66
+ m_coeff [i] = 0 ;
58
67
}
59
68
}
60
69
EKAT_REQUIRE_MSG (ifrom==from_rank,
61
70
" Error! Too many positive extents in input vector.\n " );
62
71
63
72
m_view_impl = view_type (from_v.data (),impl_layout);
64
- m_view_shape = view_type (from_v.data (),shape_layout);
65
73
}
66
74
67
75
KOKKOS_INLINE_FUNCTION
68
- int extent (int i) { return m_view_shape.extent (i); }
76
+ int extent (int i) {
77
+ EKAT_KERNEL_ASSERT_MSG (i>=0 and i<=m_view_impl.rank (),
78
+ " [ViewBroadcast::extent] Error! Index out of bounds.\n " );
79
+ return m_shape[i];
80
+ }
81
+
82
+ // Rank 1
83
+ template <typename IntType>
84
+ KOKKOS_FORCEINLINE_FUNCTION
85
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
86
+ operator ()(IntType i0) const {
87
+ return m_view_impl (i0*m_coeff[0 ]);
88
+ }
89
+
90
+ // Rank 2
91
+ template <typename IntType>
92
+ KOKKOS_FORCEINLINE_FUNCTION
93
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
94
+ operator ()(IntType i0, IntType i1) const {
95
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ]);
96
+ }
97
+
98
+ // Rank 3
99
+ template <typename IntType>
100
+ KOKKOS_FORCEINLINE_FUNCTION
101
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
102
+ operator ()(IntType i0, IntType i1, IntType i2) const {
103
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ],i2*m_coeff[2 ]);
104
+ }
105
+
106
+ // Rank 4
107
+ template <typename IntType>
108
+ KOKKOS_FORCEINLINE_FUNCTION
109
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
110
+ operator ()(IntType i0, IntType i1, IntType i2, IntType i3) const {
111
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ],i2*m_coeff[2 ],i3*m_coeff[3 ]);
112
+ }
113
+
114
+ // Rank 5
115
+ template <typename IntType>
116
+ KOKKOS_FORCEINLINE_FUNCTION
117
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
118
+ operator ()(IntType i0, IntType i1, IntType i2, IntType i3,
119
+ IntType i4) const {
120
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ],i2*m_coeff[2 ],i3*m_coeff[3 ],
121
+ i4*m_coeff[4 ]);
122
+ }
123
+
124
+ // Rank 6
125
+ template <typename IntType>
126
+ KOKKOS_FORCEINLINE_FUNCTION
127
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
128
+ operator ()(IntType i0, IntType i1, IntType i2, IntType i3,
129
+ IntType i4, IntType i5) const {
130
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ],i2*m_coeff[2 ],i3*m_coeff[3 ],
131
+ i4*m_coeff[4 ],i5*m_coeff[5 ]);
132
+ }
133
+
134
+ // Rank 7
135
+ template <typename IntType>
136
+ KOKKOS_FORCEINLINE_FUNCTION
137
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
138
+ operator ()(IntType i0, IntType i1, IntType i2, IntType i3,
139
+ IntType i4, IntType i5, IntType i6) const {
140
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ],i2*m_coeff[2 ],i3*m_coeff[3 ],
141
+ i4*m_coeff[4 ],i5*m_coeff[5 ],i6*m_coeff[6 ]);
142
+ }
69
143
70
- template <typename ... Is>
144
+ // Rank 8
145
+ template <typename IntType>
71
146
KOKKOS_FORCEINLINE_FUNCTION
72
- reference_type operator ()(Is... indices) const {
73
- int i=0 ;
74
- ((indices *= coeff[i++]),...);
75
- return m_view_impl (indices...);
147
+ std::enable_if_t <std::is_integral_v<IntType>,reference_type>
148
+ operator ()(IntType i0, IntType i1, IntType i2, IntType i3,
149
+ IntType i4, IntType i5, IntType i6, IntType i7) const {
150
+ return m_view_impl (i0*m_coeff[0 ],i1*m_coeff[1 ],i2*m_coeff[2 ],i3*m_coeff[3 ],
151
+ i4*m_coeff[4 ],i5*m_coeff[5 ],i6*m_coeff[6 ],i7*m_coeff[7 ]);
76
152
}
77
153
78
154
protected:
79
155
80
156
view_type m_view_impl;
81
- view_type m_view_shape;
82
157
83
- int coeff[8 ];
158
+ int m_shape[8 ];
159
+ int m_coeff[8 ];
84
160
};
85
161
86
162
0 commit comments