@@ -7,41 +7,41 @@ namespace kernel_float {
7
7
namespace detail {
8
8
9
9
template <typename ... Es>
10
- struct broadcast_extent_helper ;
10
+ struct broadcast_extent_impl ;
11
11
12
12
template <typename E>
13
- struct broadcast_extent_helper <E> {
13
+ struct broadcast_extent_impl <E> {
14
14
using type = E;
15
15
};
16
16
17
17
template <size_t N>
18
- struct broadcast_extent_helper <extent<N>, extent<N>> {
18
+ struct broadcast_extent_impl <extent<N>, extent<N>> {
19
19
using type = extent<N>;
20
20
};
21
21
22
22
template <size_t N>
23
- struct broadcast_extent_helper <extent<1 >, extent<N>> {
23
+ struct broadcast_extent_impl <extent<1 >, extent<N>> {
24
24
using type = extent<N>;
25
25
};
26
26
27
27
template <size_t N>
28
- struct broadcast_extent_helper <extent<N>, extent<1 >> {
28
+ struct broadcast_extent_impl <extent<N>, extent<1 >> {
29
29
using type = extent<N>;
30
30
};
31
31
32
32
template <>
33
- struct broadcast_extent_helper <extent<1 >, extent<1 >> {
33
+ struct broadcast_extent_impl <extent<1 >, extent<1 >> {
34
34
using type = extent<1 >;
35
35
};
36
36
37
37
template <typename A, typename B, typename C, typename ... Rest>
38
- struct broadcast_extent_helper <A, B, C, Rest...>:
39
- broadcast_extent_helper <typename broadcast_extent_helper <A, B>::type, C, Rest...> {};
38
+ struct broadcast_extent_impl <A, B, C, Rest...>:
39
+ broadcast_extent_impl <typename broadcast_extent_impl <A, B>::type, C, Rest...> {};
40
40
41
41
} // namespace detail
42
42
43
43
template <typename ... Es>
44
- using broadcast_extent = typename detail::broadcast_extent_helper <Es...>::type;
44
+ using broadcast_extent = typename detail::broadcast_extent_impl <Es...>::type;
45
45
46
46
template <typename ... Vs>
47
47
using broadcast_vector_extent_type = broadcast_extent<vector_extent_type<Vs>...>;
@@ -128,15 +128,24 @@ struct accurate_policy {};
128
128
* the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve
129
129
* approximations that slightly compromise precision.
130
130
*/
131
- struct fast_policy {};
131
+ struct fast_policy {
132
+ using fallback_policy = accurate_policy;
133
+ };
132
134
133
135
/* *
134
136
* This template policy allows developers to specify a custom degree of approximation for their computations. By
135
137
* adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the
136
138
* specific needs of your application. Higher values mean more precision.
137
139
*/
138
140
template <int Level = -1 >
139
- struct approx_level_policy {};
141
+ struct approx_level_policy {
142
+ using fallback_policy = approx_level_policy<>;
143
+ };
144
+
145
+ template <>
146
+ struct approx_level_policy <> {
147
+ using fallback_policy = fast_policy;
148
+ };
140
149
141
150
/* *
142
151
* The `approx_policy` alias serves as the default approximation policy, providing a standard level of approximation
@@ -145,15 +154,17 @@ struct approx_level_policy {};
145
154
*/
146
155
using approx_policy = approx_level_policy<>;
147
156
148
- #ifndef KERNEL_FLOAT_POLICY
149
- #define KERNEL_FLOAT_POLICY accurate_policy
150
- #endif
151
-
152
157
/* *
153
158
* The `default_policy` acts as the standard computation policy. It can be configured externally using the
154
- * `KERNEL_FLOAT_POLICY ` macro. If `KERNEL_FLOAT_POLICY ` is not defined, it defaults to `accurate_policy`.
159
+ * `KERNEL_FLOAT_GLOBAL_POLICY` macro. If `KERNEL_FLOAT_GLOBAL_POLICY` is not defined, the `KERNEL_FLOAT_POLICY` macro is used instead; if neither is defined, it defaults to `accurate_policy`.
155
160
*/
161
+ #if defined(KERNEL_FLOAT_GLOBAL_POLICY)
162
+ using default_policy = KERNEL_FLOAT_GLOBAL_POLICY;
163
+ #elif defined(KERNEL_FLOAT_POLICY)
156
164
using default_policy = KERNEL_FLOAT_POLICY;
165
+ #else
166
+ using default_policy = accurate_policy;
167
+ #endif
157
168
158
169
namespace detail {
159
170
@@ -164,35 +175,15 @@ struct invoke_impl {
164
175
}
165
176
};
166
177
167
- //
168
178
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
169
- struct apply_fallback_impl {
170
- KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
171
- static_assert (N > 0 , " operation not implemented" );
172
- }
173
- };
179
+ struct apply_impl ;
174
180
175
181
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
176
- struct apply_base_impl : apply_fallback_impl< Policy, F, N, Output, Args...> {};
182
+ struct apply_base_impl : apply_impl< typename Policy::fallback_policy , F, N, Output, Args...> {};
177
183
178
184
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
179
185
struct apply_impl : apply_base_impl<Policy, F, N, Output, Args...> {};
180
186
181
- // `fast_policy` falls back to `accurate_policy`
182
- template <typename F, size_t N, typename Output, typename ... Args>
183
- struct apply_fallback_impl <fast_policy, F, N, Output, Args...>:
184
- apply_impl<accurate_policy, F, N, Output, Args...> {};
185
-
186
- // `approx_policy` falls back to `fast_policy`
187
- template <typename F, size_t N, typename Output, typename ... Args>
188
- struct apply_fallback_impl <approx_policy, F, N, Output, Args...>:
189
- apply_impl<fast_policy, F, N, Output, Args...> {};
190
-
191
- // `approx_level_policy` falls back to `approx_policy`
192
- template <int Level, typename F, size_t N, typename Output, typename ... Args>
193
- struct apply_fallback_impl <approx_level_policy<Level>, F, N, Output, Args...>:
194
- apply_impl<approx_policy, F, N, Output, Args...> {};
195
-
196
187
// Only for `accurate_policy` do we implement `apply_impl`; the others will fall back to `apply_base_impl`.
197
188
template <typename F, size_t N, typename Output, typename ... Args>
198
189
struct apply_impl <accurate_policy, F, N, Output, Args...> {
@@ -266,4 +257,4 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
266
257
267
258
} // namespace kernel_float
268
259
269
- #endif // KERNEL_FLOAT_APPLY_H
260
+ #endif // KERNEL_FLOAT_APPLY_H
0 commit comments